| | |
| | """ |
| | AETHER-Micro Helper Functions |
| | |
| | 재사용 가능한 유틸리티 함수들 |
| | """ |
| |
|
| | import torch |
| |
|
| |
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile key/value heads so they line up with the query heads (GQA).

    Each KV head is repeated ``n_rep`` times back-to-back along dim 1, so
    output head ``i`` is a copy of input head ``i // n_rep``.

    Args:
        hidden_states: Tensor of shape (batch, num_kv_heads, seq_len, head_dim).
        n_rep: Copies per KV head, i.e. num_heads // num_kv_heads.

    Returns:
        Tensor of shape (batch, num_kv_heads * n_rep, seq_len, head_dim).
        When ``n_rep == 1`` the input tensor object itself is returned.
    """
    # Unpacking before the fast path keeps the 4-D shape requirement
    # enforced even when no repetition is needed.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the KV-head axis, broadcast it with
        # expand (a stride-0 view), then fold it into the head axis.
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
| |
|
| |
|
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``(-b, a)``.

    This is the 90-degree rotation step used by rotary position embeddings.
    The split point is ``size // 2``, so for an odd-sized last dimension the
    leading half is the shorter one.
    """
    mid = x.shape[-1] // 2
    front, back = x[..., :mid], x[..., mid:]
    return torch.cat([back.neg(), front], dim=-1)
| |
|
| |
|
def _squeeze_rope_table(table: torch.Tensor) -> torch.Tensor:
    """Drop singleton broadcast dims (1, then 0) from a cos/sin table, stopping at 2-D."""
    # The guards keep the result at least 2-D. An unguarded
    # ``squeeze(1).squeeze(0)`` turns a single-position (1, head_dim) table
    # into (head_dim,), after which ``table[position_ids]`` would index the
    # head_dim axis and broadcast a single scalar across every channel.
    if table.dim() > 2:
        table = table.squeeze(1)
    if table.dim() > 2:
        table = table.squeeze(0)
    return table


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim: int = 1):
    """Apply Rotary Position Embedding (RoPE) to query and key states.

    Rows of the ``cos``/``sin`` lookup tables are gathered with
    ``position_ids``. A singleton head axis is inserted, and each input is
    combined as ``x * cos + rotate_half(x) * sin``.

    Args:
        q: Query states, (batch, num_heads, seq_len, head_dim) with the
            default ``unsqueeze_dim``.
        k: Key states, (batch, num_kv_heads, seq_len, head_dim) with the
            default ``unsqueeze_dim``.
        cos: Cosine lookup table, indexed by position along its
            second-to-last axis. Examples: (max_positions, head_dim) or
            (1, 1, max_positions, head_dim); leading size-1 dims are dropped.
        sin: Sine lookup table with the same layout as ``cos``.
        position_ids: Integer positions, (batch, seq_len), used to index the
            tables.
        unsqueeze_dim: Axis at which the gathered (batch, seq_len, head_dim)
            cos/sin receive a size-1 head axis for broadcasting. The default 1
            matches q/k laid out as (batch, heads, seq_len, head_dim). Use 2
            for (batch, seq_len, heads, head_dim).

    Returns:
        Tuple ``(q_embed, k_embed)`` of rotated query and key states.
    """
    cos = _squeeze_rope_table(cos)
    sin = _squeeze_rope_table(sin)
    # Gather per-token rows -> (batch, seq_len, head_dim), then add the head axis.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
| |
|