|
|
from typing import Optional, Union

import numpy as np
|
|
|
|
|
|
class PositionalEncodingState:
    """Track temporary tensors that are stored in a driver's memory.

    Every temporary is registered under ``{prefix}_temp_{counter}_{suffix}``,
    where ``counter`` is a per-instance index that grows by one per tensor,
    so names handed out by one instance never repeat.
    """

    def __init__(self, driver, prefix: str):
        # Backend object; this class calls its create_tensor, tensor_exists
        # and delete_tensor methods.
        self.driver = driver
        self.prefix = prefix
        # Next index to embed in a temporary tensor name.
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Register ``data`` in the driver under a fresh temporary name.

        Returns the generated name so callers can fetch or free it later.
        """
        index = self.counter
        # Advance before creating, so a failed create still consumes the index.
        self.counter = index + 1
        tensor_name = f"{self.prefix}_temp_{index}_{name_suffix}"
        self.driver.create_tensor(tensor_name, data)
        return tensor_name

    def free_temp_tensor(self, name: str):
        """Delete the tensor called ``name`` from the driver, if it exists."""
        if not self.driver.tensor_exists(name):
            return
        self.driver.delete_tensor(name)
|
|
|
|
|
|
def _div_term(hidden_dim: int) -> np.ndarray:
    """Return ``10000 ** (-2k / hidden_dim)`` for k = 0 .. ceil(hidden_dim / 2) - 1."""
    return np.exp(np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))


def sinusoidal_positional_encoding(
    seq_len: int,
    hidden_dim: int,
    driver=None,
    prefix: str = "pos_enc",
) -> Union[str, np.ndarray]:
    """Build the sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * div_k)`` and column ``2k + 1`` holds
    ``cos(pos * div_k)``, with ``div_k = 10000 ** (-2k / hidden_dim)``.
    An odd ``hidden_dim`` is supported; its last column is a sine column.

    Args:
        seq_len: Number of positions (rows of the table).
        hidden_dim: Encoding width (columns of the table).
        driver: Optional tensor driver. When given, all intermediates are
            stored in driver memory and the result is left there too.
        prefix: Name prefix for tensors created in the driver.

    Returns:
        The name of the encoding tensor in ``driver`` if a driver is given,
        otherwise a ``(seq_len, hidden_dim)`` float NumPy array.
    """
    if driver is None:
        position = np.arange(seq_len)[:, np.newaxis]
        # Shape (seq_len, ceil(hidden_dim / 2)): one angle per sine column.
        angles = position * _div_term(hidden_dim)
        pe = np.zeros((seq_len, hidden_dim))
        pe[:, 0::2] = np.sin(angles)
        # With an odd hidden_dim there is one fewer cosine column than sine
        # column, so drop the last angle column.
        pe[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])
        return pe

    return _driver_positional_encoding(seq_len, hidden_dim, driver, prefix)


def _driver_positional_encoding(seq_len: int, hidden_dim: int, driver, prefix: str) -> str:
    """Driver-backed path of ``sinusoidal_positional_encoding``; returns the result tensor name."""
    state = PositionalEncodingState(driver, prefix)
    # Intermediate tensors in creation order; all are freed before returning.
    temp_names = []
    pe_name = None
    succeeded = False
    try:
        position_name = state.get_temp_tensor(
            np.arange(seq_len)[:, np.newaxis], "position"
        )
        temp_names.append(position_name)

        div_term_name = state.get_temp_tensor(_div_term(hidden_dim), "div_term")
        temp_names.append(div_term_name)

        # The output tensor; freed only if building it fails.
        pe_name = state.get_temp_tensor(np.zeros((seq_len, hidden_dim)), "pe")

        # (seq_len, 1) @ (1, n): the position/frequency outer product,
        # assuming driver.matmul is a standard matrix product.
        mul_name = state.get_temp_tensor(
            driver.matmul(
                driver.get_tensor(position_name),
                driver.get_tensor(div_term_name).reshape(1, -1),
            ),
            "multiplied",
        )
        temp_names.append(mul_name)

        # NOTE(review): sin/cos receive a tensor *name* while matmul receives
        # arrays; kept as in the original -- confirm against the driver API.
        sin_name = state.get_temp_tensor(driver.sin(mul_name), "sin")
        temp_names.append(sin_name)
        cos_name = state.get_temp_tensor(driver.cos(mul_name), "cos")
        temp_names.append(cos_name)

        # Fetch once instead of on every loop iteration.
        sin_values = driver.get_tensor(sin_name)
        cos_values = driver.get_tensor(cos_name)
        rows = np.arange(seq_len)
        for col in range(0, hidden_dim, 2):
            # (seq_len, 2) array of (row, col) index pairs for this column.
            driver.scatter(
                pe_name,
                np.column_stack((rows, np.full(seq_len, col))),
                sin_values[:, col // 2],
            )
            if col + 1 < hidden_dim:
                driver.scatter(
                    pe_name,
                    np.column_stack((rows, np.full(seq_len, col + 1))),
                    cos_values[:, col // 2],
                )
        succeeded = True
    finally:
        # Free intermediates even on failure, so driver memory is not leaked.
        for name in temp_names:
            state.free_temp_tensor(name)
        if not succeeded and pe_name is not None:
            state.free_temp_tensor(pe_name)

    return pe_name
|
|
|
|