yuboz committed · verified
Commit e75957d · Parent: 1d8e509

add support for SDPA attention

Files changed (1)
  1. modeling_siglip2.py +30 -0
modeling_siglip2.py CHANGED
@@ -505,7 +505,37 @@ class Vision_EagerAttention(nn.Module):
         return attn_output, None
 
 
+class Vision_SDPAAttention(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        dim, heads = config.hidden_size, config.num_attention_heads
+        self.num_heads, self.head_dim = heads, dim // heads
+        self.k_proj, self.v_proj, self.q_proj, self.out_proj = [nn.Linear(dim, dim) for _ in range(4)]
+        self.dropout = getattr(config, "attention_dropout", 0.0)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim), self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim), self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        attention_mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
+
+        q = q.transpose(0, 1).unsqueeze(0)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+        return self.out_proj(attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1).to(hidden_states.dtype)), None
+
+
 VISION_ATTENTION_CLASSES = {
+    'sdpa': Vision_SDPAAttention,
     'eager': Vision_EagerAttention,
     'flash_attention_2': Vision_FlashAttention2,
 }
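
The core of the new backend is how it emulates flash-attention's variable-length batching: instead of handing cu_seqlens to a varlen kernel, it materializes a dense block-diagonal additive mask so tokens only attend within their own packed sequence. The standalone sketch below (toy shapes and cu_seqlens values invented for illustration, not taken from the model) builds the mask the same way the commit does and checks that masked SDPA over the packed tokens matches running attention on each segment separately:

import torch
import torch.nn.functional as F

# Toy setup: two packed sequences of length 3 and 2 (5 tokens total),
# with cumulative boundaries cu_seqlens = [0, 3, 5].
torch.manual_seed(0)
seq_length, num_heads, head_dim = 5, 2, 4
cu_seqlens = torch.tensor([0, 3, 5])

q = torch.randn(1, num_heads, seq_length, head_dim)
k = torch.randn(1, num_heads, seq_length, head_dim)
v = torch.randn(1, num_heads, seq_length, head_dim)

# Block-diagonal additive mask, built exactly as in the commit: start every
# position at the most negative float, then zero out each sequence's own block.
mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(q.dtype).min, dtype=q.dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0

packed = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Reference: run attention on each packed segment independently.
per_segment = torch.cat(
    [
        F.scaled_dot_product_attention(q[..., s:e, :], k[..., s:e, :], v[..., s:e, :])
        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])
    ],
    dim=-2,
)
print(torch.allclose(packed, per_segment, atol=1e-6))  # True

The trade-off behind this design: the dense seq_length x seq_length mask makes the SDPA path quadratic in the packed length, whereas flash_attention_2's varlen kernel consumes cu_seqlens directly; in exchange, 'sdpa' runs on any PyTorch 2.x install without the flash-attn dependency.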
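For completeness, the registry entry is how the new backend would be selected at model build time. A hypothetical wiring sketch, assuming (as in typical transformers-style code, which this diff does not show) that the encoder reads the config's _attn_implementation attribute:

def build_vision_attention(config):
    # Hypothetical helper; the attribute name `_attn_implementation` and the
    # 'eager' fallback are assumptions about the surrounding file, not shown here.
    impl = getattr(config, "_attn_implementation", "eager")
    return VISION_ATTENTION_CLASSES[impl](config)  # Vision_SDPAAttention for "sdpa"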