Upload tiny-random deepseek_v32 model
Browse files- configuration_deepseek_v32.py +12 -0
- model.safetensors +2 -2
- modeling_deepseek_v32.py +174 -31
configuration_deepseek_v32.py
CHANGED
|
@@ -98,6 +98,12 @@ class DeepseekV32Config(PretrainedConfig):
|
|
| 98 |
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 99 |
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 100 |
The dropout ratio for the attention probabilities.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
```python
|
| 103 |
>>> from transformers import DeepseekV32Model, DeepseekV32Config
|
|
@@ -152,6 +158,9 @@ class DeepseekV32Config(PretrainedConfig):
|
|
| 152 |
rope_scaling=None,
|
| 153 |
attention_bias=False,
|
| 154 |
attention_dropout=0.0,
|
|
|
|
|
|
|
|
|
|
| 155 |
**kwargs,
|
| 156 |
):
|
| 157 |
self.vocab_size = vocab_size
|
|
@@ -192,6 +201,9 @@ class DeepseekV32Config(PretrainedConfig):
|
|
| 192 |
self.rope_scaling = rope_scaling
|
| 193 |
self.attention_bias = attention_bias
|
| 194 |
self.attention_dropout = attention_dropout
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
super().__init__(
|
| 197 |
pad_token_id=pad_token_id,
|
|
|
|
| 98 |
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 99 |
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 100 |
The dropout ratio for the attention probabilities.
|
| 101 |
+
index_n_heads (`int`, *optional*, defaults to 64):
|
| 102 |
+
Number of attention heads used in the sparse attention indexer.
|
| 103 |
+
index_head_dim (`int`, *optional*, defaults to 128):
|
| 104 |
+
Dimension of each head in the sparse attention indexer.
|
| 105 |
+
index_topk (`int`, *optional*, defaults to 2048):
|
| 106 |
+
Number of top-k key-value positions selected by the sparse attention indexer.
|
| 107 |
|
| 108 |
```python
|
| 109 |
>>> from transformers import DeepseekV32Model, DeepseekV32Config
|
|
|
|
| 158 |
rope_scaling=None,
|
| 159 |
attention_bias=False,
|
| 160 |
attention_dropout=0.0,
|
| 161 |
+
index_n_heads=64,
|
| 162 |
+
index_head_dim=128,
|
| 163 |
+
index_topk=2048,
|
| 164 |
**kwargs,
|
| 165 |
):
|
| 166 |
self.vocab_size = vocab_size
|
|
|
|
| 201 |
self.rope_scaling = rope_scaling
|
| 202 |
self.attention_bias = attention_bias
|
| 203 |
self.attention_dropout = attention_dropout
|
| 204 |
+
self.index_n_heads = index_n_heads
|
| 205 |
+
self.index_head_dim = index_head_dim
|
| 206 |
+
self.index_topk = index_topk
|
| 207 |
|
| 208 |
super().__init__(
|
| 209 |
pad_token_id=pad_token_id,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa9a2cfe783c2448f7d7e3d0d2149fe388261955b0c412f1f343a15ba17369e2
|
| 3 |
+
size 546248736
|
modeling_deepseek_v32.py
CHANGED
|
@@ -336,39 +336,39 @@ def rotate_half(x):
|
|
| 336 |
return torch.cat((-x2, x1), dim=-1)
|
| 337 |
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
|
|
|
|
|
|
| 342 |
|
| 343 |
Args:
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 350 |
-
used to pass offsetted position ids when working with a KV-cache.
|
| 351 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 352 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 353 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 354 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 355 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 356 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 357 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 358 |
-
Returns:
|
| 359 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 360 |
"""
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
|
| 367 |
-
b, h, s, d = k.shape
|
| 368 |
-
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
|
| 369 |
|
| 370 |
-
|
| 371 |
-
|
|
|
|
|
|
|
| 372 |
return q_embed, k_embed
|
| 373 |
|
| 374 |
|
|
@@ -610,6 +610,128 @@ class DeepseekV32MoE(nn.Module):
|
|
| 610 |
return final_out
|
| 611 |
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 614 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 615 |
"""
|
|
@@ -696,6 +818,9 @@ class DeepseekV32Attention(nn.Module):
|
|
| 696 |
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
| 697 |
self.softmax_scale = self.softmax_scale * mscale * mscale
|
| 698 |
|
|
|
|
|
|
|
|
|
|
| 699 |
def _init_rope(self):
|
| 700 |
if self.config.rope_scaling is None:
|
| 701 |
self.rotary_emb = DeepseekV32RotaryEmbedding(
|
|
@@ -767,8 +892,10 @@ class DeepseekV32Attention(nn.Module):
|
|
| 767 |
|
| 768 |
if self.q_lora_rank is None:
|
| 769 |
q = self.q_proj(hidden_states)
|
|
|
|
| 770 |
else:
|
| 771 |
-
|
|
|
|
| 772 |
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 773 |
q_nope, q_pe = torch.split(
|
| 774 |
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
|
@@ -823,12 +950,27 @@ class DeepseekV32Attention(nn.Module):
|
|
| 823 |
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
| 824 |
f" {attn_weights.size()}"
|
| 825 |
)
|
| 826 |
-
assert attention_mask is not None
|
| 827 |
if attention_mask is not None:
|
| 828 |
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 829 |
raise ValueError(
|
| 830 |
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 831 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 832 |
attn_weights = attn_weights + attention_mask
|
| 833 |
|
| 834 |
# upcast attention to fp32
|
|
@@ -903,7 +1045,8 @@ class DeepseekV32FlashAttention2(DeepseekV32Attention):
|
|
| 903 |
if self.q_lora_rank is None:
|
| 904 |
q = self.q_proj(hidden_states)
|
| 905 |
else:
|
| 906 |
-
|
|
|
|
| 907 |
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 908 |
q_nope, q_pe = torch.split(
|
| 909 |
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
|
|
|
| 336 |
return torch.cat((-x2, x1), dim=-1)
|
| 337 |
|
| 338 |
|
| 339 |
+
def apply_rotary_emb(x, cos, sin, position_ids, unsqueeze_dim=1, interleaved=True):
    """Applies rotary positional embeddings using complex number operations.

    Matches DeepSeek-V3.2-Exp/inference/model.py apply_rotary_emb:
    uses view_as_complex / view_as_real for the rotation.

    Args:
        x: Input tensor [batch, heads, seq_len, rope_dim].
        cos, sin: Cached cos/sin values [seq_len, rope_dim].
        position_ids: Position indices [batch, seq_len].
        unsqueeze_dim: Dimension to unsqueeze for broadcasting (default 1 for heads dim).
        interleaved: If True, consecutive pairs are (real, imag). If False, first
            half of the last dim is real, second half imag.

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.
    """
    dtype = x.dtype
    shape = x.shape
    # The cached cos/sin repeat each frequency; keep one copy per complex pair.
    half = cos.shape[-1] // 2
    # Cast to float32: torch.complex only accepts float32/float64 inputs, so a
    # half/bfloat16 cos/sin cache would otherwise raise. The rotation is computed
    # in float32 regardless (see x.float() below), so this is behavior-preserving
    # for float32 caches.
    cos_pos = cos[position_ids][..., :half].float().unsqueeze(unsqueeze_dim)
    sin_pos = sin[position_ids][..., :half].float().unsqueeze(unsqueeze_dim)
    freqs_cis = torch.complex(cos_pos, sin_pos)

    if not interleaved:
        # Regroup (first-half real, second-half imag) into adjacent (real, imag) pairs.
        x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
    x_complex = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
    y = torch.view_as_real(x_complex * freqs_cis).flatten(-2)
    if not interleaved:
        # Undo the regrouping so the output layout matches the input layout.
        y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
    return y.to(dtype)
|
| 366 |
|
|
|
|
|
|
|
| 367 |
|
| 368 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors (interleaved format)."""
    q_embed, k_embed = (
        apply_rotary_emb(tensor, cos, sin, position_ids, unsqueeze_dim, interleaved=True)
        for tensor in (q, k)
    )
    return q_embed, k_embed
|
| 373 |
|
| 374 |
|
|
|
|
| 610 |
return final_out
|
| 611 |
|
| 612 |
|
| 613 |
+
def hadamard_transform(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Pure PyTorch Hadamard transform via butterfly decomposition.

    Applies the Sylvester-ordered Walsh-Hadamard transform along the last
    dimension in O(n log n) stages, then multiplies by ``scale``.

    Args:
        x: Input tensor; its last dimension must be a power of two.
        scale: Scalar multiplier applied to the transformed output.

    Returns:
        Transformed tensor with the same shape as ``x``.

    Raises:
        ValueError: If the last dimension is not a positive power of two.
    """
    n = x.size(-1)
    # The butterfly requires a power-of-two length; without this guard a bad
    # size fails later inside unflatten with an opaque shape error.
    if n <= 0 or n & (n - 1):
        raise ValueError(f"hadamard_transform requires a power-of-two last dim, got {n}")
    h = 1
    while h < n:
        # One butterfly stage: combine adjacent blocks of size h into size 2h.
        x = x.unflatten(-1, (-1, h * 2))
        a = x[..., :h]
        b = x[..., h:]
        x = torch.cat([a + b, a - b], dim=-1).flatten(-2)
        h *= 2
    return x * scale
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
    """Applies a normalized Hadamard transform so activation magnitude is
    spread evenly across the last dimension."""
    dim = x.size(-1)
    normalization = 1.0 / (dim ** 0.5)
    return hadamard_transform(x, scale=normalization)
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
class DeepseekV32Indexer(nn.Module):
    """
    Sparse attention indexer for DeepSeek V3.2.
    Selects top-k key-value positions to attend to, enabling efficient sparse attention.

    The indexer scores every (query position, key position) pair with a small
    multi-head dot product, folds the heads together with learned per-head
    importance weights, and returns the indices of the highest-scoring key
    positions for each query.
    """

    def __init__(self, config: DeepseekV32Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.n_heads = config.index_n_heads
        self.head_dim = config.index_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.index_topk = config.index_topk
        # NOTE(review): wq_b below requires q_lora_rank to be an int — this
        # module presumably is only constructed when q_lora_rank is not None;
        # confirm against the attention module that instantiates it.
        self.q_lora_rank = config.q_lora_rank

        # Query projection from compressed q (q_lora_rank) to indexer heads
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.head_dim, bias=False)
        # Key projection from hidden states (a single shared key head)
        self.wk = nn.Linear(self.hidden_size, self.head_dim, bias=False)
        self.k_norm = nn.LayerNorm(self.head_dim)
        # Importance weighting projection: one scalar weight per indexer head
        self.weights_proj = nn.Linear(self.hidden_size, self.n_heads, bias=False)

        # Standard 1/sqrt(d) attention scaling for the index scores
        self.softmax_scale = self.head_dim ** -0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        compressed_q: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input hidden states [batch, seq_len, hidden_size]
            compressed_q: Compressed query from q_a_layernorm(q_a_proj(x)) [batch, seq_len, q_lora_rank]
            cos, sin: Rotary embedding cos/sin values
            position_ids: Position IDs
            attention_mask: Attention mask [batch, 1, seq_len, seq_len]

        Returns:
            topk_indices: Indices of top-k positions to attend to [batch, seq_len, index_topk]
        """
        bsz, q_len, _ = hidden_states.size()

        # Compute indexer queries
        q = self.wq_b(compressed_q)
        q = q.view(bsz, q_len, self.n_heads, self.head_dim)
        # Split into rope and non-rope parts; note the rope part comes FIRST
        # here (the main attention path splits in the opposite order).
        q_pe = q[..., :self.qk_rope_head_dim]
        q_nope = q[..., self.qk_rope_head_dim:]

        # Apply RoPE to query (non-interleaved in indexer, matching reference)
        q_pe = q_pe.transpose(1, 2)  # [bsz, n_heads, q_len, rope_dim]
        q_pe = apply_rotary_emb(q_pe, cos, sin, position_ids, unsqueeze_dim=1, interleaved=False)
        q_pe = q_pe.transpose(1, 2)  # back to [bsz, q_len, n_heads, rope_dim]

        q = torch.cat([q_pe, q_nope], dim=-1)  # [bsz, q_len, n_heads, head_dim]

        # Compute indexer keys (single head, shared across all query heads)
        k = self.wk(hidden_states)  # [bsz, q_len, head_dim]
        k = self.k_norm(k)
        k_pe = k[..., :self.qk_rope_head_dim]
        k_nope = k[..., self.qk_rope_head_dim:]

        # Apply RoPE to key (non-interleaved in indexer, matching reference)
        k_pe = k_pe.unsqueeze(1)  # [bsz, 1, q_len, rope_dim]
        k_pe = apply_rotary_emb(k_pe, cos, sin, position_ids, unsqueeze_dim=1, interleaved=False)
        k_pe = k_pe.squeeze(1)  # [bsz, q_len, rope_dim]

        k = torch.cat([k_pe, k_nope], dim=-1)  # [bsz, q_len, head_dim]

        # Apply Hadamard transform (from DeepSeek-V3.2-Exp/inference/model.py)
        q = rotate_activation(q)
        k = rotate_activation(k)

        # Compute importance weights; done in float32 — presumably for
        # numerical stability of the weighted score sum (TODO confirm vs reference)
        weights = self.weights_proj(hidden_states.float()) * (self.n_heads ** -0.5)  # [bsz, q_len, n_heads]

        # Compute index scores: q @ k^T scaled by weights
        # q: [bsz, q_len, n_heads, head_dim], k: [bsz, q_len, head_dim]
        # scores: [bsz, q_len, q_len] - sum over heads of (q_i @ k_j * weight_i)
        q = q.transpose(1, 2)  # [bsz, n_heads, q_len, head_dim]
        k_expanded = k.unsqueeze(1)  # [bsz, 1, q_len, head_dim]
        index_score = torch.matmul(q, k_expanded.transpose(-1, -2))  # [bsz, n_heads, q_len, q_len]
        index_score = index_score * self.softmax_scale
        # Weight by importance: weights is [bsz, q_len, n_heads] -> [bsz, n_heads, q_len, 1]
        weights = weights.permute(0, 2, 1).unsqueeze(-1)
        index_score = (index_score * weights).sum(dim=1)  # [bsz, q_len, q_len]

        if attention_mask is not None:
            # attention_mask shape: [bsz, 1, q_len, kv_len]; additive mask so
            # disallowed (e.g. future) positions score -inf and are never selected
            index_score = index_score + attention_mask.squeeze(1)

        # Select top-k indices; clamp k to the available sequence length so
        # short sequences do not make topk raise
        topk = min(self.index_topk, q_len)
        topk_indices = index_score.topk(topk, dim=-1)[1]  # [bsz, q_len, topk]

        return topk_indices
|
| 733 |
+
|
| 734 |
+
|
| 735 |
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 736 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 737 |
"""
|
|
|
|
| 818 |
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
|
| 819 |
self.softmax_scale = self.softmax_scale * mscale * mscale
|
| 820 |
|
| 821 |
+
# DeepSeek V3.2 Sparse Attention Indexer
|
| 822 |
+
self.indexer = DeepseekV32Indexer(config)
|
| 823 |
+
|
| 824 |
def _init_rope(self):
|
| 825 |
if self.config.rope_scaling is None:
|
| 826 |
self.rotary_emb = DeepseekV32RotaryEmbedding(
|
|
|
|
| 892 |
|
| 893 |
if self.q_lora_rank is None:
|
| 894 |
q = self.q_proj(hidden_states)
|
| 895 |
+
compressed_q = None
|
| 896 |
else:
|
| 897 |
+
compressed_q = self.q_a_layernorm(self.q_a_proj(hidden_states))
|
| 898 |
+
q = self.q_b_proj(compressed_q)
|
| 899 |
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 900 |
q_nope, q_pe = torch.split(
|
| 901 |
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
|
|
|
| 950 |
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
| 951 |
f" {attn_weights.size()}"
|
| 952 |
)
|
|
|
|
| 953 |
if attention_mask is not None:
|
| 954 |
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 955 |
raise ValueError(
|
| 956 |
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 957 |
)
|
| 958 |
+
|
| 959 |
+
# DeepSeek V3.2: Apply sparse attention indexer mask (includes causal mask)
|
| 960 |
+
# Matching reference: causal mask is applied only once, via index_mask
|
| 961 |
+
if compressed_q is not None:
|
| 962 |
+
topk_indices = self.indexer(
|
| 963 |
+
hidden_states, compressed_q, cos, sin, position_ids, attention_mask
|
| 964 |
+
)
|
| 965 |
+
# Create sparse index mask: only attend to top-k positions
|
| 966 |
+
index_mask = torch.full(
|
| 967 |
+
(bsz, q_len, kv_seq_len), float("-inf"), device=hidden_states.device
|
| 968 |
+
)
|
| 969 |
+
index_mask.scatter_(-1, topk_indices, 0.0)
|
| 970 |
+
if attention_mask is not None:
|
| 971 |
+
index_mask = index_mask + attention_mask.squeeze(1)
|
| 972 |
+
attn_weights = attn_weights + index_mask.unsqueeze(1)
|
| 973 |
+
elif attention_mask is not None:
|
| 974 |
attn_weights = attn_weights + attention_mask
|
| 975 |
|
| 976 |
# upcast attention to fp32
|
|
|
|
| 1045 |
if self.q_lora_rank is None:
|
| 1046 |
q = self.q_proj(hidden_states)
|
| 1047 |
else:
|
| 1048 |
+
compressed_q = self.q_a_layernorm(self.q_a_proj(hidden_states))
|
| 1049 |
+
q = self.q_b_proj(compressed_q)
|
| 1050 |
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
|
| 1051 |
q_nope, q_pe = torch.split(
|
| 1052 |
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|