INV / helium /positional_encoding.py
Fred808's picture
Upload 256 files
7a0c684 verified
import numpy as np
from typing import Optional
class PositionalEncodingState:
    """Track scratch tensors that a computation stores in a driver.

    Each tensor created through an instance is named
    ``{prefix}_temp_{n}_{suffix}``. ``n`` is a per-instance counter that grows
    by one per created tensor, so every name handed out by one instance is
    distinct.
    """

    def __init__(self, driver, prefix: str):
        # The driver must provide create_tensor / tensor_exists / delete_tensor.
        self.driver = driver
        self.prefix = prefix
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Create a driver tensor holding ``data`` and return its generated name."""
        index = self.counter
        self.counter = index + 1
        tensor_name = f"{self.prefix}_temp_{index}_{name_suffix}"
        self.driver.create_tensor(tensor_name, data)
        return tensor_name

    def free_temp_tensor(self, name: str):
        """Delete tensor ``name`` from the driver; do nothing if it is absent."""
        if not self.driver.tensor_exists(name):
            return
        self.driver.delete_tensor(name)
def sinusoidal_positional_encoding(
    seq_len: int,
    hidden_dim: int,
    driver=None,
    prefix: str = "pos_enc",
) -> "str | np.ndarray":
    """Build the sinusoidal positional-encoding table of shape (seq_len, hidden_dim).

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, where ``w_k = 10000 ** (-2k / hidden_dim)``. An odd
    ``hidden_dim`` is supported; the last column is then a sine column with
    no cosine partner.

    Args:
        seq_len: Number of positions (rows of the table).
        hidden_dim: Width of each encoding (columns of the table).
        driver: Optional tensor driver. When ``None``, the table is computed
            eagerly with NumPy. Otherwise it is built inside the driver using
            its ``create_tensor``/``get_tensor``/``matmul``/``sin``/``cos``/
            ``scatter`` calls.
        prefix: Name prefix for every tensor created in the driver.

    Returns:
        A NumPy array when ``driver`` is ``None``. Otherwise, the name of the
        driver tensor that holds the table. Intermediate driver tensors are
        always freed. If an error occurs, the partially built output tensor
        is freed too before the exception propagates.
    """
    if driver is None:
        # NumPy fallback: compute everything in host memory.
        position = np.arange(seq_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))
        angles = position * div_term  # (seq_len, ceil(hidden_dim / 2))
        pe = np.zeros((seq_len, hidden_dim))
        pe[:, 0::2] = np.sin(angles)
        # With an odd hidden_dim there is one fewer cosine column than sine
        # column; slicing the angles keeps the assignment shapes compatible.
        pe[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])
        return pe

    state = PositionalEncodingState(driver, prefix)
    scratch = []  # intermediate tensor names; freed on every exit path
    pe_name = None
    succeeded = False
    try:
        # Creation order determines the counter embedded in each tensor name;
        # the output tensor is the third tensor created.
        position_name = state.get_temp_tensor(
            np.arange(seq_len)[:, np.newaxis], "position"
        )
        scratch.append(position_name)

        log_term = -np.log(10000.0) / hidden_dim
        div_indices = np.arange(0, hidden_dim, 2)
        div_term_name = state.get_temp_tensor(
            np.exp(div_indices * log_term), "div_term"
        )
        scratch.append(div_term_name)

        pe_name = state.get_temp_tensor(np.zeros((seq_len, hidden_dim)), "pe")

        # Column of positions times the frequencies reshaped to a row: an
        # outer product giving one angle per (position, frequency).
        mul_name = state.get_temp_tensor(
            driver.matmul(
                driver.get_tensor(position_name),
                driver.get_tensor(div_term_name).reshape(1, -1),
            ),
            "multiplied",
        )
        scratch.append(mul_name)

        sin_name = state.get_temp_tensor(driver.sin(mul_name), "sin")
        scratch.append(sin_name)
        cos_name = state.get_temp_tensor(driver.cos(mul_name), "cos")
        scratch.append(cos_name)

        # Fetch the sin/cos tables once rather than once per column.
        sin_vals = driver.get_tensor(sin_name)
        cos_vals = driver.get_tensor(cos_name)
        for freq, col in enumerate(range(0, hidden_dim, 2)):
            # The even column gets the sine values for this frequency.
            driver.scatter(
                pe_name,
                np.array([(row, col) for row in range(seq_len)]),
                sin_vals[:, freq],
            )
            # The odd column (absent past the end for odd hidden_dim) gets
            # the matching cosine values.
            if col + 1 < hidden_dim:
                driver.scatter(
                    pe_name,
                    np.array([(row, col + 1) for row in range(seq_len)]),
                    cos_vals[:, freq],
                )
        succeeded = True
    finally:
        for name in scratch:
            state.free_temp_tensor(name)
        # Do not leak a half-built output tensor when something failed.
        if not succeeded and pe_name is not None:
            state.free_temp_tensor(pe_name)
    return pe_name