# TaoNet-pico-T1 / rope.py
# Uploaded by Lobakkang: "Upload TaoNet model to HuggingFace Hub" (commit 2981407, verified)
"""Rotary Position Embedding (RoPE) implementation."""
import torch
import torch.nn as nn
import math
class RotaryEmbedding(nn.Module):
    """Rotary position embeddings (RoPE): per-position rotation angles.

    Stores the inverse frequencies ``base ** (-2i / dim)`` for
    ``i = 0 .. dim/2 - 1`` in the ``inv_freq`` buffer (a regular buffer, so
    it is part of ``state_dict``). ``forward`` returns the angle table
    ``(pos / scale) * inv_freq``, duplicated across both halves of the last
    dimension.

    Args:
        dim: Rotary dimension; must be even.
        scale: Positions are divided by this factor before being multiplied
            by the frequencies (linear position interpolation).
        base: Frequency base; 10000 is the classic RoPE value.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim, scale=40, base=10000):
        super().__init__()
        # Frequencies come one per rotated pair of channels, so an odd dim
        # cannot be split evenly. Raise rather than assert so the check
        # survives ``python -O``.
        if dim % 2 != 0:
            raise ValueError("Dimension must be even for rotary embeddings")
        self.dim = dim
        self.scale = scale
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len, device):
        """Return the rotation-angle table for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions.
            device: Device to build the table on. ``None`` means the device
                of the ``inv_freq`` buffer.

        Returns:
            Tensor of shape ``(seq_len, dim)`` with the dtype of ``inv_freq``.
            Row ``p`` is ``(p / scale) * inv_freq`` repeated twice.
        """
        # ``type_as`` would cast to the buffer's full tensor type, which
        # includes its device, and silently override ``device``. Move the
        # buffer to the requested device instead.
        inv_freq = self.inv_freq if device is None else self.inv_freq.to(device)
        positions = torch.arange(seq_len, device=inv_freq.device).to(inv_freq.dtype)
        t = positions / self.scale
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate so column k and column k + dim/2 share an angle, matching
        # the half-split pairing used by ``rotate_half``.
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    """Map the two halves ``(a, b)`` of the last dimension to ``(-b, a)``.

    Splitting follows ``torch.chunk``: for an odd width the first half is the
    larger one, and a width-1 input yields a single chunk and fails to unpack.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)


def apply_rotary(x, cos, sin):
    """Rotate the leading feature channels of ``x`` by the given angles.

    First, ``cos`` and ``sin`` are clipped to ``x``'s width. Then the first
    ``cos.shape[-1]`` channels of ``x`` are rotated as
    ``x * cos + rotate_half(x) * sin``. Any remaining channels are appended
    unchanged (partial rotary).

    Args:
        x: Input tensor whose last dimension holds the features.
        cos: Cosine table, broadcastable against the rotated slice of ``x``.
            Presumably the ``.cos()`` of a ``RotaryEmbedding`` output; confirm
            at the call site.
        sin: Sine table with the same layout as ``cos``.

    Returns:
        Tensor with the same last-dimension width as ``x``.
    """
    width = x.shape[-1]
    # Tables wider than x are clipped to x's width.
    # NOTE(review): clipping an [f, f]-duplicated table to a narrower width
    # breaks the pairing rotate_half relies on (column j and column
    # j + width/2 no longer hold the same angle), so the result is not a pure
    # rotation. Confirm callers never pass tables wider than x.
    cos = cos[..., :width]
    sin = sin[..., :width]
    # Tables narrower than x: only the first rot_width channels are rotated.
    rot_width = cos.shape[-1]
    head = x[..., :rot_width]
    tail = x[..., rot_width:]
    rotated = head * cos + rotate_half(head) * sin
    if tail.shape[-1] == 0:
        # Nothing passes through unrotated; skip the concatenation copy.
        return rotated
    return torch.cat([rotated, tail], dim=-1)