File size: 3,385 Bytes
7a0c684 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from typing import Optional, Union

import numpy as np
class PositionalEncodingState:
    """Book-keeping helper that creates and removes numbered temporary tensors.

    Every tensor created through this object is named
    ``<prefix>_temp_<counter>_<suffix>``, where ``counter`` starts at 0 and
    grows by one per created tensor.
    """

    def __init__(self, driver, prefix: str):
        # driver: object exposing create_tensor / tensor_exists / delete_tensor
        # prefix: leading part of every generated tensor name
        # counter: index used for the next generated name
        self.driver = driver
        self.prefix = prefix
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Create a tensor holding *data* on the driver and return its name."""
        index = self.counter
        tensor_name = f"{self.prefix}_temp_{index}_{name_suffix}"
        # The counter advances before the driver call, so an index is never
        # reused even if create_tensor raises.
        self.counter = index + 1
        self.driver.create_tensor(tensor_name, data)
        return tensor_name

    def free_temp_tensor(self, name: str):
        """Delete the tensor called *name* from the driver if it exists."""
        drv = self.driver
        if not drv.tensor_exists(name):
            return
        drv.delete_tensor(name)
def sinusoidal_positional_encoding(
    seq_len: int,
    hidden_dim: int,
    driver=None,
    prefix: str = "pos_enc",
) -> Union[str, np.ndarray]:
    """Build a ``(seq_len, hidden_dim)`` sinusoidal positional-encoding table.

    Column ``2k`` holds ``sin(pos * div_term[k])`` and column ``2k + 1`` holds
    ``cos(pos * div_term[k])``, with
    ``div_term[k] = exp(-2k * ln(10000) / hidden_dim)``. An odd ``hidden_dim``
    is supported: the last column is a sin column with no cos partner.

    Args:
        seq_len: Number of positions (rows).
        hidden_dim: Encoding width (columns).
        driver: Optional tensor driver. When given, the table is assembled in
            a tensor created on the driver and intermediate tensors are
            removed afterwards. The driver must provide ``create_tensor``,
            ``get_tensor``, ``tensor_exists``, ``delete_tensor``, ``matmul``,
            ``sin``, ``cos`` and ``scatter``.
        prefix: Name prefix for tensors created on the driver.

    Returns:
        The encoding as a numpy array when ``driver`` is ``None``; otherwise
        the name of the result tensor stored on the driver.

    Note:
        Each call starts a fresh name counter, so two calls with the same
        driver and ``prefix`` yield the same result name. How the driver
        treats a repeated name is not visible here -- verify before relying
        on an earlier result staying intact.
    """
    log_term = -np.log(10000.0) / hidden_dim
    div_term = np.exp(np.arange(0, hidden_dim, 2) * log_term)

    if driver is None:
        # Pure numpy fallback.
        position = np.arange(seq_len)[:, np.newaxis]
        angles = position * div_term
        pe = np.zeros((seq_len, hidden_dim))
        pe[:, 0::2] = np.sin(angles)
        # With an odd hidden_dim there is one fewer cos column than sin
        # column, so drop the unpaired last frequency.
        pe[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])
        return pe

    state = PositionalEncodingState(driver, prefix)
    intermediates = []  # temporaries to delete no matter how we leave
    pe_name = None
    try:
        position_name = state.get_temp_tensor(
            np.arange(seq_len)[:, np.newaxis],
            "position",
        )
        intermediates.append(position_name)

        div_term_name = state.get_temp_tensor(div_term, "div_term")
        intermediates.append(div_term_name)

        # Output tensor, zero-initialised and filled by scatter below.
        pe_name = state.get_temp_tensor(
            np.zeros((seq_len, hidden_dim)),
            "pe",
        )

        # (seq_len, 1) @ (1, n_freq): every position times every frequency.
        mul_name = state.get_temp_tensor(
            driver.matmul(
                driver.get_tensor(position_name),
                driver.get_tensor(div_term_name).reshape(1, -1),
            ),
            "multiplied",
        )
        intermediates.append(mul_name)

        # NOTE(review): sin/cos receive the tensor *name* while matmul above
        # receives tensor data -- kept as in the original; confirm against
        # the driver API.
        sin_name = state.get_temp_tensor(driver.sin(mul_name), "sin")
        intermediates.append(sin_name)
        cos_name = state.get_temp_tensor(driver.cos(mul_name), "cos")
        intermediates.append(cos_name)

        # Fetch once and reuse; only the target column changes per iteration.
        sin_values = driver.get_tensor(sin_name)
        cos_values = driver.get_tensor(cos_name)
        rows = np.arange(seq_len)

        for col in range(0, hidden_dim, 2):
            freq = col // 2
            # Even columns receive the sin values.
            driver.scatter(
                pe_name,
                np.column_stack((rows, np.full(seq_len, col))),
                sin_values[:, freq],
            )
            # Odd columns receive the cos values (absent for the last
            # frequency when hidden_dim is odd).
            if col + 1 < hidden_dim:
                driver.scatter(
                    pe_name,
                    np.column_stack((rows, np.full(seq_len, col + 1))),
                    cos_values[:, freq],
                )
    except Exception:
        # A partially filled result is useless to the caller; do not leave it
        # behind on the driver.
        if pe_name is not None:
            state.free_temp_tensor(pe_name)
        raise
    finally:
        for name in intermediates:
            state.free_temp_tensor(name)

    return pe_name
|