from typing import Dict, List, Optional, Tuple, Union

import numpy as np


def embedding_lookup(
    input_ids: np.ndarray,
    embedding_weights: np.ndarray,
    driver=None
) -> np.ndarray:
    """
    Look up embeddings for input tokens.

    Args:
        input_ids: Input token indices of shape (batch_size, sequence_length)
        embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
        driver: Optional hardware driver for optimized lookup

    Returns:
        Embedded tokens of shape (batch_size, sequence_length, hidden_dim)
    """
    if driver is not None and hasattr(driver, 'embedding_lookup'):
        return driver.embedding_lookup(input_ids, embedding_weights)

    # Fallback to numpy implementation
    batch_size, seq_length = input_ids.shape
    hidden_dim = embedding_weights.shape[1]

    # Flatten indices, gather the corresponding rows, then restore the
    # (batch_size, sequence_length, hidden_dim) layout
    flat_ids = input_ids.reshape(-1)
    embeddings = embedding_weights[flat_ids]
    return embeddings.reshape(batch_size, seq_length, hidden_dim)


def add_positional_encoding(
    embeddings: np.ndarray,
    max_position: int,
    hidden_dim: int,
    dtype: np.dtype = np.float32,
    driver=None
) -> np.ndarray:
    """
    Add sinusoidal positional encodings to input embeddings.

    Args:
        embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum sequence length to precompute encodings for
        hidden_dim: Hidden dimension size
        dtype: Data type for positional encodings
        driver: Optional hardware driver for optimized computation

    Returns:
        Embeddings with positional encoding added
    """
    if driver is not None and hasattr(driver, 'add_positional_encoding'):
        return driver.add_positional_encoding(
            embeddings, max_position, hidden_dim, dtype
        )

    # Fallback to numpy implementation
    _, seq_length, _ = embeddings.shape

    # Precompute the encoding table up to max_position (at least seq_length)
    table_length = max(seq_length, max_position)
    position = np.arange(table_length)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )

    # Even columns get sine, odd columns get cosine
    pos_encoding = np.zeros((table_length, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term[: hidden_dim // 2])

    # Slice to the actual sequence length, add a batch dimension, and add
    return embeddings + pos_encoding[np.newaxis, :seq_length, :]


class EmbeddingState:
    """Tracks short-lived tensors created in driver memory during a pass."""

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store temporary computation results in driver memory"""
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)
        return name

    def free_temp_tensor(self, name: str):
        """Clean up temporary tensors"""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
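
# ---------------------------------------------------------------------------
# Illustrative only: a minimal in-memory driver sketch showing the interface
# that EmbeddingState and the Embedding layer below assume (create_tensor,
# get_tensor, tensor_exists, delete_tensor, random_normal, gather, reshape,
# scatter_add, fill). The real backend (e.g. SQLiteMemoryManager) is expected
# to expose the same methods; this dict-backed version exists only so the
# module can be exercised and tested without that backend.
# ---------------------------------------------------------------------------
class InMemoryDriver:
    """Dict-backed reference driver; tensors are addressed by name."""

    def __init__(self):
        self._tensors: Dict[str, np.ndarray] = {}

    def create_tensor(self, name: str, data) -> None:
        self._tensors[name] = np.asarray(data)

    def get_tensor(self, name: str) -> np.ndarray:
        return self._tensors[name]

    def tensor_exists(self, name: str) -> bool:
        return name in self._tensors

    def delete_tensor(self, name: str) -> None:
        self._tensors.pop(name, None)

    def random_normal(self, shape, mean: float = 0.0, std: float = 1.0) -> np.ndarray:
        return np.random.normal(mean, std, size=shape)

    def gather(self, weight_name: str, indices_name: str) -> np.ndarray:
        # Row lookup: weights[indices]
        return self._tensors[weight_name][self._tensors[indices_name]]

    def reshape(self, name: str, shape) -> np.ndarray:
        return self._tensors[name].reshape(shape)

    def scatter_add(self, target_name: str, indices: np.ndarray, source_name: str) -> None:
        # Accumulate source rows into target rows selected by indices
        np.add.at(self._tensors[target_name], indices, self._tensors[source_name])

    def fill(self, name: str, value: float) -> None:
        self._tensors[name].fill(value)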
""" def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02): self.vocab_size = vocab_size self.embedding_dim = embedding_dim self.driver = driver self.prefix = prefix # Create unique names for persistent tensors self.weight_name = f"{prefix}_weight" self.grad_name = f"{prefix}_grad" # Initialize embedding matrix in driver memory if not present if not driver.tensor_exists(self.weight_name): weights = driver.random_normal( (vocab_size, embedding_dim), mean=0.0, std=init_std ) driver.create_tensor(self.weight_name, weights) # Initialize gradient tensor driver.create_tensor( self.grad_name, np.zeros((vocab_size, embedding_dim)) ) def forward( self, indices_name: str, training: bool = True ) -> str: """ All operations in driver memory indices_name: name of tensor containing indices in driver Returns: name of output tensor in driver """ state = EmbeddingState(self.driver, f"{self.prefix}_fwd") # Get shape info from driver indices = self.driver.get_tensor(indices_name) original_shape = indices.shape # Flatten indices in driver memory flat_name = state.get_temp_tensor( indices.reshape(-1), "flat" ) # Gather embeddings in driver memory gathered_name = state.get_temp_tensor( self.driver.gather(self.weight_name, flat_name), "gathered" ) state.free_temp_tensor(flat_name) # Reshape to original dimensions + embedding_dim output_shape = original_shape + (self.embedding_dim,) output_name = state.get_temp_tensor( self.driver.reshape(gathered_name, output_shape), "output" ) state.free_temp_tensor(gathered_name) if training: # Store intermediate results needed for backward self.save_for_backward(indices_name, original_shape) return output_name def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]): """Save tensors needed for backward pass in driver memory""" self.driver.create_tensor( f"{self.prefix}_cache_indices", self.driver.get_tensor(indices_name) ) self.driver.create_tensor( f"{self.prefix}_cache_shape", np.array(shape) ) def backward(self, grad_output_name: str) -> None: """ Compute gradients in driver memory grad_output_name: name of gradient tensor in driver """ state = EmbeddingState(self.driver, f"{self.prefix}_bwd") # Get cached values from driver indices = self.driver.get_tensor(f"{self.prefix}_cache_indices") orig_shape = tuple(self.driver.get_tensor(f"{self.prefix}_cache_shape")) # Reshape gradient to match gathered shape reshaped_grad_name = state.get_temp_tensor( self.driver.reshape(grad_output_name, (-1, self.embedding_dim)), "reshaped_grad" ) # Use scatter_add to accumulate gradients for each index self.driver.scatter_add( self.grad_name, # Accumulate into gradient tensor indices.reshape(-1), # Flattened indices reshaped_grad_name # Reshaped gradients ) state.free_temp_tensor(reshaped_grad_name) # Cleanup cached tensors self.driver.delete_tensor(f"{self.prefix}_cache_indices") self.driver.delete_tensor(f"{self.prefix}_cache_shape") def parameters(self) -> Dict[str, str]: """Return names of parameter tensors in driver""" return { "weight": self.weight_name, "grad": self.grad_name } def zero_grad(self) -> None: """Reset gradients to zero in driver memory""" self.driver.fill(self.grad_name, 0.0)