alexnasa commited on
Commit
8267b13
·
verified ·
1 Parent(s): bcc5314

Update sam2/modeling/sam/transformer.py

Browse files
Files changed (1) hide show
  1. sam2/modeling/sam/transformer.py +33 -39
sam2/modeling/sam/transformer.py CHANGED
@@ -288,50 +288,44 @@ class RoPEAttention(Attention):
288
  self.freqs_cis = freqs_cis
289
  self.rope_k_repeat = rope_k_repeat
290
 
291
- def forward(
292
- self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
293
- ) -> Tensor:
294
- # Input projections
295
  q = self.q_proj(q)
296
  k = self.k_proj(k)
297
  v = self.v_proj(v)
298
-
299
- # # Separate into heads
300
- # q = self._separate_heads(q, self.num_heads)
301
- # k = self._separate_heads(k, self.num_heads)
302
- # v = self._separate_heads(v, self.num_heads)
303
-
304
- q = rearrange(q, "b s (n d) -> b s n d", n=self.num_heads)
305
- k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
306
- v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
307
-
308
- # Apply rotary position encoding
309
- w = h = math.sqrt(q.shape[-2])
310
- self.freqs_cis = self.freqs_cis.to(q.device)
311
- if self.freqs_cis.shape[0] != q.shape[-2]:
312
- self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
313
- if q.shape[-2] != k.shape[-2]:
314
  assert self.rope_k_repeat
315
-
316
- num_k_rope = k.size(-2) - num_k_exclude_rope
317
- q, k[:, :, :num_k_rope] = apply_rotary_enc(
318
- q,
319
- k[:, :, :num_k_rope],
320
  freqs_cis=self.freqs_cis,
321
  repeat_freqs_k=self.rope_k_repeat,
322
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- dropout_p = self.dropout_p if self.training else 0.0
325
-
326
- # #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
327
- # out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
328
-
329
- # out = self._recombine_heads(out)
330
-
331
- out = flash_attn_interface.flash_attn_func(q, k, v) # -> [b, s_q, n, d]
332
-
333
- out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
334
-
335
- out = self.out_proj(out)
336
-
337
- return out
 
288
  self.freqs_cis = freqs_cis
289
  self.rope_k_repeat = rope_k_repeat
290
 
291
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
292
+
 
 
293
  q = self.q_proj(q)
294
  k = self.k_proj(k)
295
  v = self.v_proj(v)
296
+
297
+ # 1) reshape to (B, H, S, D) so RoPE sees the sequence at dim -2
298
+ q_hsd = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
299
+ k_hsd = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
300
+ v_hsd = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
301
+
302
+ # 2) RoPE expects S at -2
303
+ S = q_hsd.shape[-2]
304
+ w = h = math.sqrt(S)
305
+ self.freqs_cis = self.freqs_cis.to(q_hsd.device)
306
+ if self.freqs_cis.shape[0] != S:
307
+ self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q_hsd.device)
308
+ if q_hsd.shape[-2] != k_hsd.shape[-2]:
 
 
 
309
  assert self.rope_k_repeat
310
+
311
+ num_k_rope = k_hsd.size(-2) - num_k_exclude_rope
312
+ q_hsd, k_hsd[:, :, :num_k_rope] = apply_rotary_enc(
313
+ q_hsd,
314
+ k_hsd[:, :, :num_k_rope],
315
  freqs_cis=self.freqs_cis,
316
  repeat_freqs_k=self.rope_k_repeat,
317
  )
318
+
319
+ # 3) switch to (B, S, H, D) for FlashAttention
320
+ q_bshd = rearrange(q_hsd, "b h s d -> b s h d")
321
+ k_bshd = rearrange(k_hsd, "b h s d -> b s h d")
322
+ v_bshd = rearrange(v_hsd, "b h s d -> b s h d")
323
+
324
+ out = flash_attn_interface.flash_attn_func(
325
+ q_bshd, k_bshd, v_bshd,
326
+ dropout_p=self.dropout_p if self.training else 0.0
327
+ ) # (B, S, H, D)
328
+
329
+ out = rearrange(out, "b s h d -> b s (h d)")
330
+ return self.out_proj(out)
331