#!/usr/bin/env python3
"""
AETHER-Micro Helper Functions

Reusable utility functions for attention layers: K/V head repetition for
grouped-query attention and rotary position embedding (RoPE).
"""

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat K/V heads for Grouped Query Attention (GQA).

    Each K/V head is duplicated ``n_rep`` times along the head axis, with the
    copies of one head kept adjacent (head 0 copies first, then head 1, ...).

    Args:
        hidden_states: (batch, num_kv_heads, seq_len, head_dim)
        n_rep: num_heads // num_kv_heads

    Returns:
        (batch, num_kv_heads * n_rep, seq_len, head_dim). When ``n_rep == 1``
        the input tensor itself is returned (no copy).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the K/V head axis and broadcast it;
    # expand() is a view, the data is only materialized by reshape() below.
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate half the hidden dims of the input.

    The last dimension is split at ``x.shape[-1] // 2`` into ``(x1, x2)`` and
    returned as ``(-x2, x1)`` concatenated along the last dimension.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def _as_position_table(table: torch.Tensor, name: str) -> torch.Tensor:
    """Collapse singleton leading axes of a cos/sin table to (num_positions, head_dim)."""
    if table.dim() < 2:
        raise ValueError(
            f"{name} must have at least 2 dimensions (num_positions, head_dim), "
            f"got shape {tuple(table.shape)}"
        )
    # At most one axis before head_dim may be non-singleton: the position axis.
    # Anything else (e.g. a real batch axis) cannot be indexed by position_ids.
    if sum(size != 1 for size in table.shape[:-1]) > 1:
        raise ValueError(
            f"{name} must be a position table whose leading dimensions are "
            f"singleton, got shape {tuple(table.shape)}"
        )
    # reshape (rather than squeeze) keeps the position axis even when the
    # table holds a single position; squeeze would drop that axis as well.
    return table.reshape(-1, table.shape[-1])


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: torch.Tensor,
) -> tuple:
    """
    Apply Rotary Position Embedding to Q and K.

    Args:
        q: query states (batch, num_heads, seq_len, head_dim)
        k: key states (batch, num_kv_heads, seq_len, head_dim)
        cos: cosine table indexed by position; (num_positions, head_dim),
            optionally with singleton leading axes such as
            (1, num_positions, head_dim) or (1, 1, num_positions, head_dim)
        sin: sine table, same layout as ``cos``
        position_ids: integer position indices (batch, seq_len) used to
            gather rows from ``cos`` / ``sin``

    Returns:
        q_embed, k_embed: rotated query and key, same shapes as ``q`` and ``k``

    Raises:
        ValueError: if ``cos`` or ``sin`` is not a position table (fewer than
            two dimensions, or more than one non-singleton leading axis).
    """
    # Normalize both tables to (num_positions, head_dim).
    cos = _as_position_table(cos, "cos")
    sin = _as_position_table(sin, "sin")

    # Gather per-token rows, then add a head axis so they broadcast over heads.
    cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
    sin = sin[position_ids].unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed