# llm-from-scratch / v2 / usta_embedding.py
# Author: Celal Şamil Kartoğlu
# v2 implemented (commit 17c0b30)
import torch
import torch.nn as nn
def get_rotary_position_encoding(input: torch.Tensor, base: int = 10000, device: str = "cpu") -> torch.Tensor:
    """Apply rotary position encoding (RoPE, half-split layout) to ``input``.

    The embedding dimension is split into two halves; each (even-half,
    odd-half) channel pair is rotated by a position-dependent angle so the
    encoded vectors carry information about their sequence position.

    Args:
        input: tensor of shape (batch_size, context_length, dimension);
            dimension must be even.
        base: base of the geometric frequency progression (10000 in the
            RoPE paper).
        device: device on which the angle tables are built.

    Returns:
        A tensor of the same shape as ``input`` with positions encoded.
    """
    batch_size, context_length, dimension = input.shape
    assert dimension % 2 == 0, "dimension must be even"
    half_dimension = dimension // 2

    freqs_indices = torch.arange(0, half_dimension, device=device, dtype=torch.float32)
    # NOTE(review): the RoPE paper uses exponent 2i/d (i.e. indices / half_dimension);
    # this code uses indices / dimension, which compresses the frequency range.
    # Kept as-is to preserve the module's existing numerics — confirm intent.
    freqs = 1.0 / (base ** (freqs_indices / dimension))

    # angles: (context_length, half_dimension); broadcasts over the batch dim.
    positions = torch.arange(0, context_length, device=device, dtype=torch.float32).unsqueeze(1)
    angles = positions * freqs
    sin_angles = torch.sin(angles)
    cos_angles = torch.cos(angles)

    input_even = input[:, :, :half_dimension]
    input_odd = input[:, :, half_dimension:]
    # 2-D rotation of each (even, odd) channel pair by its angle.
    input_even_rotated = input_even * cos_angles - input_odd * sin_angles
    input_odd_rotated = input_even * sin_angles + input_odd * cos_angles

    input_rotated = torch.empty_like(input)
    input_rotated[:, :, :half_dimension] = input_even_rotated
    # BUG FIX: the original wrote the rotated odd half into [:, :, :half_dimension]
    # (a typo'd slice — trailing ':' made it the first half again), overwriting the
    # even half and leaving the second half as uninitialized empty_like memory.
    input_rotated[:, :, half_dimension:] = input_odd_rotated
    return input_rotated
class UstaEmbedding(nn.Module):
    """Token embedding layer with rotary position encoding.

    Looks up a learned vector for each token id, then rotates those vectors
    according to their sequence position via ``get_rotary_position_encoding``.
    """

    def __init__(self, vocab_size, embedding_dim, context_length, device):
        super().__init__()
        # A learned absolute position embedding was sketched here for
        # educational purposes but is intentionally unused — rotary position
        # encoding (applied in forward) takes its place:
        #   self.pos_embedding = nn.Embedding(context_length, embedding_dim)
        self.embedding = nn.Embedding(vocab_size, embedding_dim, device=device)
        self.get_pos = get_rotary_position_encoding
        self.device = device

    def forward(self, x):
        """Embed token ids ``x`` and encode their positions."""
        # Dictionary meaning of the tokens, then their meaning within the
        # sentence according to position.
        token_vectors = self.embedding(x)
        return self.get_pos(token_vectors, device=self.device)