File size: 1,814 Bytes
de40e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
"""
AETHER-Micro Helper Functions

Reusable utility functions.
"""

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Duplicate every key/value head ``n_rep`` times for Grouped Query Attention.

    Copies of the same KV head end up adjacent on the head axis
    (head 0 repeated n_rep times, then head 1, ...), i.e. the same layout as
    ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``.

    Args:
        hidden_states: tensor of shape (batch, num_kv_heads, seq_len, head_dim)
        n_rep: copies per KV head (num_heads // num_kv_heads)

    Returns:
        Tensor of shape (batch, num_kv_heads * n_rep, seq_len, head_dim).
        When ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpacking before the shortcut also enforces that the input is 4-D.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Add a singleton "repeat" axis right after the head axis and broadcast it;
    # expand() yields a view, so nothing is copied at this point.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)

    # Fold (kv_heads, n_rep) into a single head axis.
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)


def rotate_half(x):
    """
    Swap the two halves of the last dimension, negating the second half.

    With the last dimension viewed as ``[a, b]``, where ``a`` holds the first
    ``size // 2`` entries and ``b`` the rest, the result is ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat([-back, front], dim=-1)


def _squeeze_rope_table(table):
    """Drop singleton axes 1 and 0 of a cos/sin table without going below 2-D."""
    if table.dim() < 2:
        raise ValueError(
            f"cos/sin must have at least 2 dimensions, got shape {tuple(table.shape)}"
        )
    # Axis 1 first, then axis 0. An unconditional squeeze would also collapse a
    # single-position table such as (1, 1, head_dim) down to 1-D, after which
    # indexing by position would pick scalars instead of rows.
    for axis in (1, 0):
        if table.dim() > 2 and table.shape[axis] == 1:
            table = table.squeeze(axis)
    return table


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply Rotary Position Embedding to Q and K

    Args:
        q: query states (batch, num_heads, seq_len, head_dim)
        k: key states (batch, num_kv_heads, seq_len, head_dim)
        cos: cosine lookup table indexed by position, typically
            (1, 1, num_positions, head_dim); (1, num_positions, head_dim)
            and (num_positions, head_dim) are accepted as well
        sin: sine lookup table, same layout as ``cos``
        position_ids: position indices (batch, seq_len) used to pick rows
            from the tables

    Returns:
        q_embed, k_embed: rotated query and key, same shapes as ``q`` and ``k``

    Raises:
        ValueError: if ``cos`` or ``sin`` has fewer than 2 dimensions
    """
    # Reduce the tables to (num_positions, head_dim) by removing singleton
    # leading axes, never collapsing the position axis itself.
    cos = _squeeze_rope_table(cos)
    sin = _squeeze_rope_table(sin)

    # Gather one row per position, then add a head axis so the result
    # broadcasts over heads: (batch, 1, seq_len, head_dim).
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed