import numpy as np
from typing import Optional, Union


class PositionalEncodingState:
    """Create and release uniquely named temporary tensors in a driver.

    Every tensor created through ``get_temp_tensor`` is named
    ``"{prefix}_temp_{counter}_{name_suffix}"``; the counter is incremented
    per call so names produced by one instance never collide.
    """

    def __init__(self, driver, prefix: str):
        self.driver = driver  # backend providing create_tensor / tensor_exists / delete_tensor
        self.prefix = prefix  # namespace prepended to every generated tensor name
        self.counter = 0  # bumped on each get_temp_tensor call to keep names unique

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store ``data`` in driver memory under a fresh, unique name.

        Args:
            data: Value handed unchanged to ``driver.create_tensor``.
            name_suffix: Readable tag appended to the generated name.

        Returns:
            The name under which the tensor was created.
        """
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)
        return name

    def free_temp_tensor(self, name: str) -> None:
        """Delete tensor ``name`` from driver memory if it still exists."""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)


def sinusoidal_positional_encoding(
    seq_len: int,
    hidden_dim: int,
    driver=None,
    prefix: str = "pos_enc",
) -> Union[str, np.ndarray]:
    """Build the sinusoidal positional-encoding table of shape (seq_len, hidden_dim).

    Column ``2k`` holds ``sin(pos * w_k)`` and column ``2k + 1`` holds
    ``cos(pos * w_k)``, with ``w_k = exp(2k * -ln(10000) / hidden_dim)``.
    An odd ``hidden_dim`` is supported: the last column is a sine column with
    no matching cosine column.

    Args:
        seq_len: Number of positions (rows).
        hidden_dim: Encoding width (columns).
        driver: Optional tensor backend. When given, all intermediate and
            final tensors are created in its memory.
        prefix: Namespace for the names of tensors created in the driver.

    Returns:
        When ``driver`` is given, the name of the encoding tensor in driver
        memory (intermediate tensors are freed before returning). Otherwise a
        float64 ``np.ndarray`` of shape ``(seq_len, hidden_dim)``.
    """
    # Frequencies w_k for k = 0 .. ceil(hidden_dim / 2) - 1, shared by both paths.
    div_term = np.exp(np.arange(0, hidden_dim, 2) * (-np.log(10000.0) / hidden_dim))

    if driver is None:
        # Pure-NumPy fallback.
        position = np.arange(seq_len)[:, np.newaxis]
        angles = position * div_term  # shape (seq_len, ceil(hidden_dim / 2))
        pe = np.zeros((seq_len, hidden_dim))
        pe[:, 0::2] = np.sin(angles)
        # With odd hidden_dim there is one fewer cosine column than sine
        # columns, so the last frequency has no cosine partner.
        pe[:, 1::2] = np.cos(angles[:, : hidden_dim // 2])
        return pe

    state = PositionalEncodingState(driver, prefix)
    intermediates = []  # temp tensor names freed on exit, success or failure
    pe_name = None
    succeeded = False
    try:
        position_name = state.get_temp_tensor(
            np.arange(seq_len)[:, np.newaxis], "position"
        )
        intermediates.append(position_name)
        div_term_name = state.get_temp_tensor(div_term, "div_term")
        intermediates.append(div_term_name)

        # Output tensor. Created in this position so its generated name
        # (counter value) stays the same as before.
        pe_name = state.get_temp_tensor(np.zeros((seq_len, hidden_dim)), "pe")

        # Outer product of positions and frequencies: (seq_len, 1) @ (1, n_freq).
        mul_name = state.get_temp_tensor(
            driver.matmul(
                driver.get_tensor(position_name),
                driver.get_tensor(div_term_name).reshape(1, -1),
            ),
            "multiplied",
        )
        intermediates.append(mul_name)

        # NOTE(review): sin/cos are given a tensor *name* whereas matmul is
        # given arrays from get_tensor; kept as-is -- confirm the driver API.
        sin_name = state.get_temp_tensor(driver.sin(mul_name), "sin")
        intermediates.append(sin_name)
        cos_name = state.get_temp_tensor(driver.cos(mul_name), "cos")
        intermediates.append(cos_name)

        # Fetch once rather than once per column.
        sin_values = driver.get_tensor(sin_name)
        cos_values = driver.get_tensor(cos_name)
        rows = np.arange(seq_len)

        # Presumably scatter writes values at (row, col) index pairs -- the
        # indices are built as an (seq_len, 2) array of such pairs.
        for col in range(0, hidden_dim, 2):
            freq = col // 2
            # Even columns get sine values.
            driver.scatter(
                pe_name,
                np.column_stack((rows, np.full(seq_len, col))),
                sin_values[:, freq],
            )
            # Odd columns get cosine values (absent for the last column of an
            # odd hidden_dim).
            if col + 1 < hidden_dim:
                driver.scatter(
                    pe_name,
                    np.column_stack((rows, np.full(seq_len, col + 1))),
                    cos_values[:, freq],
                )
        succeeded = True
    finally:
        for name in intermediates:
            state.free_temp_tensor(name)
        # On failure the output tensor is incomplete and unreachable by the
        # caller, so release it too.
        if not succeeded and pe_name is not None:
            state.free_temp_tensor(pe_name)

    return pe_name