Add support for SDPA attention
modeling_siglip2.py  CHANGED  (+30, -0)
@@ -505,7 +505,37 @@ class Vision_EagerAttention(nn.Module):
         return attn_output, None
 
 
+class Vision_SDPAAttention(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        dim, heads = config.hidden_size, config.num_attention_heads
+        self.num_heads, self.head_dim = heads, dim // heads
+        self.k_proj, self.v_proj, self.q_proj, self.out_proj = [nn.Linear(dim, dim) for _ in range(4)]
+        self.dropout = getattr(config, "attention_dropout", 0.0)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings=None):
+        seq_length = hidden_states.shape[0]
+        q, k, v = self.q_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim), self.k_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim), self.v_proj(hidden_states).view(seq_length, self.num_heads, self.head_dim)
+        if position_embeddings is None:
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+        attention_mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype)
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i-1]:cu_seqlens[i], cu_seqlens[i-1]:cu_seqlens[i]] = 0
+
+        q = q.transpose(0, 1).unsqueeze(0)
+        k = k.transpose(0, 1).unsqueeze(0)
+        v = v.transpose(0, 1).unsqueeze(0)
+        attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
+        return self.out_proj(attn_output.squeeze(0).transpose(0, 1).reshape(seq_length, -1).to(hidden_states.dtype)), None
+
+
 VISION_ATTENTION_CLASSES = {
+    'sdpa': Vision_SDPAAttention,
     'eager': Vision_EagerAttention,
     'flash_attention_2': Vision_FlashAttention2,
 }
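
The core of the new class is the additive block-diagonal mask built from cu_seqlens, which keeps attention confined to each packed sequence. A minimal, self-contained sketch of the same construction (the helper name and the example boundaries are mine, not from the patch):

import torch

def block_diagonal_mask(cu_seqlens, dtype=torch.float32):
    # cu_seqlens holds cumulative boundaries, e.g. [0, 3, 7] for two packed
    # sequences of lengths 3 and 4. Everything starts at the most negative
    # finite value; positions inside each block are reset to 0, so adding
    # the mask to the attention logits zeroes out cross-sequence attention.
    total = int(cu_seqlens[-1])
    mask = torch.full([1, 1, total, total], torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    return mask

print((block_diagonal_mask(torch.tensor([0, 3, 7]))[0, 0] == 0).int())
# Prints a 7x7 matrix with a 3x3 and a 4x4 block of ones on the diagonal.

The dense mask costs O(L^2) memory in the total packed length, which is the main trade-off against the flash_attention_2 path, whose variable-length kernels can consume cu_seqlens directly. Note also that self.dropout is stored in __init__ but not forwarded as dropout_p to scaled_dot_product_attention in this hunk.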
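
With 'sdpa' registered, a vision block can select the implementation by key. The call site is not part of this hunk; assuming the usual transformers-style config._attn_implementation attribute, dispatch would look roughly like:

# Hypothetical call site -- the patch only shows the registry, not the
# selection logic; the attribute name is an assumption.
attn_class = VISION_ATTENTION_CLASSES[getattr(config, "_attn_implementation", "eager")]
self.attn = attn_class(config)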