alex committed on
Commit c4d567b · 1 Parent(s): 7c5fe35

cpu variation

Files changed (1)
  1. sam2/modeling/sam/transformer.py +89 -51
sam2/modeling/sam/transformer.py CHANGED
@@ -16,6 +16,7 @@ from torch import Tensor, nn
 from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
 from sam2.modeling.sam2_utils import MLP
 from sam2.utils.misc import get_sdp_backends
+
 import flash_attn_interface
 from einops import rearrange
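
Review note: this hunk only adds a blank line, and flash_attn_interface is still imported unconditionally, so a CPU-only environment fails at import time before the new device check below ever runs. A minimal sketch of an optional import; the FLASH_ATTN_AVAILABLE flag is introduced here for illustration and is not part of the repo:

    try:
        import flash_attn_interface
        FLASH_ATTN_AVAILABLE = True
    except ImportError:
        # CPU-only machine: fall back to F.scaled_dot_product_attention everywhere
        flash_attn_interface = None
        FLASH_ATTN_AVAILABLE = False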
 
@@ -240,28 +241,30 @@ class Attention(nn.Module):
         k = self.k_proj(k)
         v = self.v_proj(v)
 
-        # # Separate into heads
-        # q = self._separate_heads(q, self.num_heads)
-        # k = self._separate_heads(k, self.num_heads)
-        # v = self._separate_heads(v, self.num_heads)
-
-        # dropout_p = self.dropout_p if self.training else 0.0
-        # # Attention
-
-        # #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
-        # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
-
-        # out = self._recombine_heads(out)
-
-        q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
-        k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
-        v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
-
-        out = flash_attn_interface.flash_attn_func(q, k, v)  # -> [b, s_q, n, d]
-
-        out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
-
-        out = self.out_proj(out)
+        if q.device.type == "cpu":  # flash-attn kernels are CUDA-only
+            # Separate into heads
+            q = self._separate_heads(q, self.num_heads)
+            k = self._separate_heads(k, self.num_heads)
+            v = self._separate_heads(v, self.num_heads)
+
+            dropout_p = self.dropout_p if self.training else 0.0
+            # Attention
+            # with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+            out = self._recombine_heads(out)
+        else:
+            q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
+            k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
+            v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
+
+            out = flash_attn_interface.flash_attn_func(q, k, v)  # -> [b, s_q, n, d]
+
+            out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
+
+        out = self.out_proj(out)
 
         return out
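
The two branches use different head layouts: _separate_heads yields (B, H, S, D) for SDPA, while the rearrange path feeds flash_attn in (B, S, H, D). A quick layout sanity check that the two conventions agree, using plain SDPA on both (no flash-attn required); shapes mirror the rearrange patterns in the hunk:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    B, S, H, D = 2, 16, 4, 8
    x = torch.randn(B, S, H * D)

    # CPU path layout: _separate_heads -> (B, H, S, D)
    q_sdpa = rearrange(x, "b s (h d) -> b h s d", h=H)

    # FlashAttention path layout: (B, S, H, D)
    q_flash = rearrange(x, "b s (n d) -> b s n d", n=H)

    # SDPA attends over dim -2, so the flash layout is transposed in and out
    out_sdpa = F.scaled_dot_product_attention(q_sdpa, q_sdpa, q_sdpa)
    out_flash_like = F.scaled_dot_product_attention(
        q_flash.transpose(1, 2), q_flash.transpose(1, 2), q_flash.transpose(1, 2)
    ).transpose(1, 2)

    torch.testing.assert_close(
        rearrange(out_sdpa, "b h s d -> b s (h d)"),
        rearrange(out_flash_like, "b s n d -> b s (n d)"),
    )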
 
@@ -288,43 +291,78 @@ class RoPEAttention(Attention):
         self.freqs_cis = freqs_cis
         self.rope_k_repeat = rope_k_repeat
 
-    def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
-
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+    ) -> Tensor:
+        # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
         v = self.v_proj(v)
 
-        # 1) reshape to (B, H, S, D) so RoPE sees the sequence at dim -2
-        q_hsd = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
-        k_hsd = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
-        v_hsd = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
-
-        # 2) RoPE expects S at -2
-        S = q_hsd.shape[-2]
-        w = h = math.sqrt(S)
-        self.freqs_cis = self.freqs_cis.to(q_hsd.device)
-        if self.freqs_cis.shape[0] != S:
-            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q_hsd.device)
-        if q_hsd.shape[-2] != k_hsd.shape[-2]:
-            assert self.rope_k_repeat
-
-        num_k_rope = k_hsd.size(-2) - num_k_exclude_rope
-        q_hsd, k_hsd[:, :, :num_k_rope] = apply_rotary_enc(
-            q_hsd,
-            k_hsd[:, :, :num_k_rope],
-            freqs_cis=self.freqs_cis,
-            repeat_freqs_k=self.rope_k_repeat,
-        )
-
-        # 3) switch to (B, S, H, D) for FlashAttention
-        q_bshd = rearrange(q_hsd, "b h s d -> b s h d")
-        k_bshd = rearrange(k_hsd, "b h s d -> b s h d")
-        v_bshd = rearrange(v_hsd, "b h s d -> b s h d")
-
-        out = flash_attn_interface.flash_attn_func(
-            q_bshd, k_bshd, v_bshd
-        )  # (B, S, H, D)
-
-        out = rearrange(out, "b s h d -> b s (h d)")
-        return self.out_proj(out)
+        if q.device.type == "cpu":  # flash-attn kernels are CUDA-only
+            # Separate into heads
+            q = self._separate_heads(q, self.num_heads)
+            k = self._separate_heads(k, self.num_heads)
+            v = self._separate_heads(v, self.num_heads)
+
+            # Apply rotary position encoding
+            w = h = math.sqrt(q.shape[-2])
+            self.freqs_cis = self.freqs_cis.to(q.device)
+            if self.freqs_cis.shape[0] != q.shape[-2]:
+                self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+            if q.shape[-2] != k.shape[-2]:
+                assert self.rope_k_repeat
+
+            num_k_rope = k.size(-2) - num_k_exclude_rope
+            q, k[:, :, :num_k_rope] = apply_rotary_enc(
+                q,
+                k[:, :, :num_k_rope],
+                freqs_cis=self.freqs_cis,
+                repeat_freqs_k=self.rope_k_repeat,
+            )
+
+            dropout_p = self.dropout_p if self.training else 0.0
+
+            # with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+            out = self._recombine_heads(out)
+            out = self.out_proj(out)
+
+            return out
+        else:
+            # 1) reshape to (B, H, S, D) so RoPE sees the sequence at dim -2
+            q_hsd = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
+            k_hsd = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
+            v_hsd = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+
+            # 2) RoPE expects S at -2
+            S = q_hsd.shape[-2]
+            w = h = math.sqrt(S)
+            self.freqs_cis = self.freqs_cis.to(q_hsd.device)
+            if self.freqs_cis.shape[0] != S:
+                self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q_hsd.device)
+            if q_hsd.shape[-2] != k_hsd.shape[-2]:
+                assert self.rope_k_repeat
+
+            num_k_rope = k_hsd.size(-2) - num_k_exclude_rope
+            q_hsd, k_hsd[:, :, :num_k_rope] = apply_rotary_enc(
+                q_hsd,
+                k_hsd[:, :, :num_k_rope],
+                freqs_cis=self.freqs_cis,
+                repeat_freqs_k=self.rope_k_repeat,
+            )
+
+            # 3) switch to (B, S, H, D) for FlashAttention
+            q_bshd = rearrange(q_hsd, "b h s d -> b s h d")
+            k_bshd = rearrange(k_hsd, "b h s d -> b s h d")
+            v_bshd = rearrange(v_hsd, "b h s d -> b s h d")
+
+            out = flash_attn_interface.flash_attn_func(
+                q_bshd, k_bshd, v_bshd
+            )  # (B, S, H, D)
+
+            out = rearrange(out, "b s h d -> b s (h d)")
+            return self.out_proj(out)
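
Both branches rebuild freqs_cis via w = h = math.sqrt(S), i.e. they assume the tokens form a square 2D grid, which holds for SAM 2's square feature maps but not for arbitrary sequence lengths. A hypothetical guard making that assumption explicit (grid_side is a name introduced here, not in the repo):

    import math

    def grid_side(S: int) -> int:
        # Axial RoPE is applied over a w x h token grid, so S must be a
        # perfect square (e.g. a 64x64 feature map gives S = 4096, side 64).
        side = math.isqrt(S)
        assert side * side == S, f"sequence length {S} is not a square grid"
        return side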
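
Since the point of the commit is that the CPU branch should be a drop-in stand-in for the FlashAttention branch, a parity check is worth keeping next to it. A sketch under the assumption that a CUDA machine with flash-attn is available; check_cpu_gpu_parity is a hypothetical helper, and the fp16 cast (flash-attn requires half precision) is why the tolerances are loose:

    import copy
    import torch

    def check_cpu_gpu_parity(attn, q, k, v):
        # CPU branch -> F.scaled_dot_product_attention
        out_cpu = attn(q, k, v)

        # CUDA branch -> flash_attn_interface.flash_attn_func (needs fp16/bf16)
        attn_gpu = copy.deepcopy(attn).cuda().half()
        out_gpu = attn_gpu(q.cuda().half(), k.cuda().half(), v.cuda().half())

        torch.testing.assert_close(
            out_cpu, out_gpu.float().cpu(), atol=2e-2, rtol=2e-2
        )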