AETHER-Micro-0.5B / utils.py
Be2Jay's picture
Upload AETHER-Micro 0.5B Phase 1 checkpoint (Step 57000)
de40e7d verified
#!/usr/bin/env python3
"""
AETHER-Micro Helper Functions
Reusable utility functions.
"""
import torch
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Expand grouped K/V heads so each query head has a matching K/V head (GQA).

    Every key/value head is repeated ``n_rep`` times consecutively along dim 1,
    so KV head ``j`` ends up at output heads ``j * n_rep`` .. ``(j + 1) * n_rep - 1``.

    Args:
        hidden_states: tensor of shape (batch, num_kv_heads, seq_len, head_dim)
        n_rep: repetition factor, i.e. num_heads // num_kv_heads
    Returns:
        Tensor of shape (batch, num_kv_heads * n_rep, seq_len, head_dim).
        When ``n_rep == 1`` the input tensor object itself is returned.
    """
    # Unpacking up front also enforces the 4-D input contract.
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it with expand,
        # then fold it into the head axis.
        expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
        hidden_states = expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
    return hidden_states
def rotate_half(x):
    """
    Rotate the last dimension by half, as used by RoPE.

    The last dim is split into a front part (the first ``size // 2`` elements)
    and a back part (the remaining elements); the result is ``cat(-back, front)``.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat([back.neg(), front], dim=-1)
def _rope_table(t: torch.Tensor, name: str) -> torch.Tensor:
    """Collapse a cos/sin cache to a 2-D (num_positions, head_dim) lookup table."""
    if t.dim() < 2:
        raise ValueError(
            f"{name} must have shape (..., num_positions, head_dim), got {tuple(t.shape)}"
        )
    # Besides head_dim, only one axis (the position axis) may be non-singleton.
    non_unit = [d for d in t.shape[:-1] if d != 1]
    if len(non_unit) > 1:
        raise ValueError(
            f"{name} must be a per-position table whose other dims are all 1 "
            f"(e.g. (1, 1, num_positions, head_dim)), got {tuple(t.shape)}"
        )
    num_positions = non_unit[0] if non_unit else 1
    return t.reshape(num_positions, t.shape[-1])


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply Rotary Position Embedding (RoPE) to Q and K.

    ``cos``/``sin`` are per-position lookup tables that are gathered with
    ``position_ids``; they are not expected to be pre-gathered per batch.

    Args:
        q: query states (batch, num_heads, seq_len, head_dim)
        k: key states (batch, num_kv_heads, seq_len, head_dim)
        cos: cosine table (num_positions, head_dim); additional size-1 dims are
            allowed, e.g. (1, 1, num_positions, head_dim)
        sin: sine table, same layout as ``cos``
        position_ids: integer indices into the table, (batch, seq_len), or
            (seq_len,) to use the same positions for every batch element
    Returns:
        q_embed, k_embed: rotated query and key. No dtype cast is applied, so
        the result dtype follows torch type promotion of q/k with cos/sin.
    Raises:
        ValueError: if cos/sin have fewer than 2 dims, or more than one
            non-singleton dim besides head_dim.
    """
    # Collapse size-1 dims explicitly. A blind squeeze(1).squeeze(0) would also
    # drop the position axis when it has length 1, making the gather below index
    # along head_dim. It would also let a (batch>1, seq, dim) tensor through and
    # index it by batch.
    cos = _rope_table(cos, "cos")  # (num_positions, head_dim)
    sin = _rope_table(sin, "sin")
    if position_ids.dim() == 1:
        # Shared positions for the whole batch.
        position_ids = position_ids.unsqueeze(0)
    # Gather rows per token and add a head axis for broadcasting against q/k:
    # (batch, seq_len, head_dim) -> (batch, 1, seq_len, head_dim)
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed