from typing import Dict, Tuple

import numpy as np

def embedding_lookup(
    input_ids: np.ndarray,
    embedding_weights: np.ndarray,
    driver=None
) -> np.ndarray:
    """
    Look up embeddings for input tokens.

    Args:
        input_ids: Input token indices of shape (batch_size, sequence_length)
        embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
        driver: Optional hardware driver for optimized lookup

    Returns:
        Embedded tokens of shape (batch_size, sequence_length, hidden_dim)
    """
    # Fast path: delegate to the driver if it provides an optimized lookup.
    if driver and hasattr(driver, 'embedding_lookup'):
        return driver.embedding_lookup(input_ids, embedding_weights)

    # Pure NumPy fallback: flatten the indices, index the weight matrix, and
    # restore the batch/sequence layout.
    batch_size, seq_length = input_ids.shape
    hidden_dim = embedding_weights.shape[1]

    input_ids_reshaped = input_ids.reshape(-1)
    embeddings = embedding_weights[input_ids_reshaped]

    return embeddings.reshape(batch_size, seq_length, hidden_dim)
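

# The sketch below is illustrative only: it exercises the pure NumPy fallback
# path of embedding_lookup with made-up shapes (no driver involved). The
# _demo_* name is an assumption, not part of the module's API.
def _demo_embedding_lookup() -> None:
    rng = np.random.default_rng(0)
    weights = rng.standard_normal((100, 16)).astype(np.float32)  # (vocab_size, hidden_dim)
    ids = np.array([[1, 5, 7], [2, 0, 99]])                      # (batch_size, sequence_length)
    out = embedding_lookup(ids, weights)
    assert out.shape == (2, 3, 16)
    assert np.array_equal(out[0, 1], weights[5])                 # each output row comes from the table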


def add_positional_encoding(
    embeddings: np.ndarray,
    max_position: int,
    hidden_dim: int,
    dtype: np.dtype = np.float32,
    driver=None
) -> np.ndarray:
    """
    Add sinusoidal positional encodings to input embeddings.

    Args:
        embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum sequence length supported by the encoding table
        hidden_dim: Hidden dimension size
        dtype: Data type for positional encodings
        driver: Optional hardware driver for optimized computation

    Returns:
        Embeddings with positional encoding added
    """
    # Fast path: delegate to the driver if it provides an optimized implementation.
    if driver and hasattr(driver, 'add_positional_encoding'):
        return driver.add_positional_encoding(
            embeddings,
            max_position,
            hidden_dim,
            dtype
        )

    _, seq_length, _ = embeddings.shape
    if seq_length > max_position:
        raise ValueError(
            f"sequence length {seq_length} exceeds max_position {max_position}"
        )

    # Standard sinusoidal table: PE[pos, 2i] = sin(pos / 10000^(2i / d)) and
    # PE[pos, 2i + 1] = cos(pos / 10000^(2i / d)), built up to max_position.
    position = np.arange(max_position)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )

    pos_encoding = np.zeros((max_position, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    # Slice div_term so an odd hidden_dim does not cause a shape mismatch.
    pos_encoding[:, 1::2] = np.cos(position * div_term[: hidden_dim // 2])

    # Broadcast over the batch dimension and truncate to the actual sequence length.
    pos_encoding = pos_encoding[np.newaxis, :, :]
    return embeddings + pos_encoding[:, :seq_length, :]
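

# Illustrative sketch only: runs the NumPy fallback of add_positional_encoding on
# zero embeddings, so the output equals the sinusoidal table itself. The _demo_*
# name is an assumption, not part of the module's API.
def _demo_add_positional_encoding() -> None:
    x = np.zeros((2, 4, 8), dtype=np.float32)              # (batch, seq, hidden)
    y = add_positional_encoding(x, max_position=16, hidden_dim=8)
    assert y.shape == x.shape
    assert np.isclose(y[0, 0, 0], 0.0)                     # sin(0) at position 0
    assert np.isclose(y[0, 0, 1], 1.0)                     # cos(0) at position 0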


class EmbeddingState:
    """Allocates uniquely named temporary tensors in driver memory."""

    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0

    def get_temp_tensor(self, data, name_suffix: str = "") -> str:
        """Store a temporary computation result in driver memory and return its name."""
        name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
        self.counter += 1
        self.driver.create_tensor(name, data)
        return name

    def free_temp_tensor(self, name: str) -> None:
        """Clean up a temporary tensor if it still exists."""
        if self.driver.tensor_exists(name):
            self.driver.delete_tensor(name)
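

# Illustrative lifecycle sketch for EmbeddingState (assumes `driver` implements
# create_tensor, tensor_exists and delete_tensor; the names below are hypothetical):
#
#   state = EmbeddingState(driver, "layer0_fwd")
#   tmp = state.get_temp_tensor(np.arange(4), "flat")   # -> "layer0_fwd_temp_0_flat"
#   ...                                                 # read back via driver.get_tensor(tmp)
#   state.free_temp_tensor(tmp)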


class Embedding:
    """
    GPU/DB-backed embedding layer for NLP/graph models.

    All weights and tensors are stored and accessed via the driver
    (e.g. SQLiteMemoryManager), not Python RAM.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.driver = driver
        self.prefix = prefix

        self.weight_name = f"{prefix}_weight"
        self.grad_name = f"{prefix}_grad"

        # Initialize weights only when they are not already persisted by the driver.
        if not driver.tensor_exists(self.weight_name):
            weights = driver.random_normal(
                (vocab_size, embedding_dim),
                mean=0.0,
                std=init_std
            )
            driver.create_tensor(self.weight_name, weights)

        # Create the gradient buffer independently of the weight check, so it also
        # exists when the weights were loaded from an earlier run.
        if not driver.tensor_exists(self.grad_name):
            driver.create_tensor(
                self.grad_name,
                np.zeros((vocab_size, embedding_dim))
            )

    def forward(
        self,
        indices_name: str,
        training: bool = True
    ) -> str:
        """
        Embedding lookup executed entirely in driver memory.

        Args:
            indices_name: Name of the driver tensor containing token indices.
            training: If True, cache the indices for the backward pass.

        Returns:
            Name of the output tensor in the driver (this method does not free it).
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_fwd")

        indices = self.driver.get_tensor(indices_name)
        original_shape = indices.shape

        # Flatten the indices so a single gather covers every token position.
        flat_name = state.get_temp_tensor(
            indices.reshape(-1),
            "flat"
        )

        # Gather the embedding rows for the flattened indices.
        gathered_name = state.get_temp_tensor(
            self.driver.gather(self.weight_name, flat_name),
            "gathered"
        )
        state.free_temp_tensor(flat_name)

        # Restore the original index shape with the embedding dimension appended.
        output_shape = original_shape + (self.embedding_dim,)
        output_name = state.get_temp_tensor(
            self.driver.reshape(gathered_name, output_shape),
            "output"
        )
        state.free_temp_tensor(gathered_name)

        if training:
            self.save_for_backward(indices_name, original_shape)

        return output_name

    def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]):
        """Save tensors needed for backward pass in driver memory"""
        self.driver.create_tensor(
            f"{self.prefix}_cache_indices",
            self.driver.get_tensor(indices_name)
        )
        self.driver.create_tensor(
            f"{self.prefix}_cache_shape",
            np.array(shape)
        )

    def backward(self, grad_output_name: str) -> None:
        """
        Accumulate embedding gradients in driver memory.

        Args:
            grad_output_name: Name of the upstream gradient tensor in the driver,
                shaped like the forward output.
        """
        state = EmbeddingState(self.driver, f"{self.prefix}_bwd")

        indices = self.driver.get_tensor(f"{self.prefix}_cache_indices")

        # Collapse the upstream gradient to (num_tokens, embedding_dim) so each
        # row lines up with one flattened index.
        reshaped_grad_name = state.get_temp_tensor(
            self.driver.reshape(grad_output_name, (-1, self.embedding_dim)),
            "reshaped_grad"
        )

        # Scatter-add each gradient row into the weight-gradient row selected by
        # its token index.
        self.driver.scatter_add(
            self.grad_name,
            indices.reshape(-1),
            reshaped_grad_name
        )

        state.free_temp_tensor(reshaped_grad_name)

        # Drop the tensors cached by the forward pass.
        self.driver.delete_tensor(f"{self.prefix}_cache_indices")
        self.driver.delete_tensor(f"{self.prefix}_cache_shape")
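
    # Note on the backward pass above: for a token id v, the weight-gradient row
    # is dL/dW[v] = sum of the upstream gradient rows at every position whose
    # input index equals v, which is exactly what the scatter_add call computes.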

    def parameters(self) -> Dict[str, str]:
        """Return names of parameter tensors in driver"""
        return {
            "weight": self.weight_name,
            "grad": self.grad_name
        }

    def zero_grad(self) -> None:
        """Reset gradients to zero in driver memory"""
        self.driver.fill(self.grad_name, 0.0)
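

# Usage sketch for the driver-backed layer (illustrative only; assumes a driver
# such as SQLiteMemoryManager exposing the create_tensor/get_tensor/gather/reshape/
# scatter_add/fill/tensor_exists/delete_tensor/random_normal calls used above;
# "token_ids" and grad_name are hypothetical tensor names):
#
#   embed = Embedding(vocab_size=30000, embedding_dim=128, driver=driver)
#   driver.create_tensor("token_ids", np.array([[1, 2, 3]]))
#   out_name = embed.forward("token_ids")      # (1, 3, 128) tensor held by the driver
#   embed.backward(grad_name)                  # accumulates into embed.grad_name
#   embed.zero_grad()                          # clears the gradient buffer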