# INV/helium/embedding.py
from typing import Dict, Tuple

import numpy as np

def embedding_lookup(
input_ids: np.ndarray,
embedding_weights: np.ndarray,
driver=None
) -> np.ndarray:
"""
Look up embeddings for input tokens.
Args:
input_ids: Input token indices of shape (batch_size, sequence_length)
embedding_weights: Embedding weight matrix of shape (vocab_size, hidden_dim)
driver: Optional hardware driver for optimized lookup
Returns:
Embedded tokens of shape (batch_size, sequence_length, hidden_dim)
"""
if driver and hasattr(driver, 'embedding_lookup'):
return driver.embedding_lookup(input_ids, embedding_weights)
# Fallback to numpy implementation
batch_size, seq_length = input_ids.shape
hidden_dim = embedding_weights.shape[1]
    # Flatten indices so they can be used for fancy indexing into the weights
    input_ids_reshaped = input_ids.reshape(-1)
# Lookup embeddings
embeddings = embedding_weights[input_ids_reshaped]
# Reshape back to (batch_size, sequence_length, hidden_dim)
return embeddings.reshape(batch_size, seq_length, hidden_dim)
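
# A hedged usage sketch of the numpy fallback path (no driver attached).
# The toy vocabulary size and dimensions below are illustrative only:
#
#     rng = np.random.default_rng(0)
#     weights = rng.normal(size=(100, 16)).astype(np.float32)  # (vocab, dim)
#     ids = np.array([[1, 5, 7], [2, 0, 9]])                   # (batch, seq)
#     out = embedding_lookup(ids, weights)
#     assert out.shape == (2, 3, 16)
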
def add_positional_encoding(
embeddings: np.ndarray,
max_position: int,
hidden_dim: int,
dtype: np.dtype = np.float32,
driver=None
) -> np.ndarray:
"""
Add positional encodings to input embeddings.
Args:
embeddings: Input embeddings of shape (batch_size, sequence_length, hidden_dim)
        max_position: Maximum number of positions supported; sequence_length must not exceed it
hidden_dim: Hidden dimension size
dtype: Data type for positional encodings
driver: Optional hardware driver for optimized computation
Returns:
Embeddings with positional encoding added
"""
if driver and hasattr(driver, 'add_positional_encoding'):
return driver.add_positional_encoding(
embeddings,
max_position,
hidden_dim,
dtype
)
    # Fallback to numpy implementation
    seq_length = embeddings.shape[1]
    if seq_length > max_position:
        raise ValueError(
            f"sequence length {seq_length} exceeds max_position {max_position}"
        )
# Create position indices
position = np.arange(seq_length)[:, np.newaxis]
    # div_term[i] = 1 / 10000**(2i / hidden_dim): the inverse wavelengths
    # of the sinusoids, computed in log space for numerical stability
    div_term = np.exp(
        np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim)
    )
# Calculate positional encodings
pos_encoding = np.zeros((seq_length, hidden_dim), dtype=dtype)
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    # Truncate div_term for the cosine slice so an odd hidden_dim does not
    # cause a shape mismatch (the 1::2 slice has one fewer column than 0::2)
    pos_encoding[:, 1::2] = np.cos(position * div_term[: hidden_dim // 2])
    # Add a batch dimension so the encoding broadcasts over the batch
    pos_encoding = pos_encoding[np.newaxis, :, :]
    return embeddings + pos_encoding
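
# The fallback above implements the standard sinusoidal encoding from
# "Attention Is All You Need":
#   PE[pos, 2i]   = sin(pos / 10000**(2i / hidden_dim))
#   PE[pos, 2i+1] = cos(pos / 10000**(2i / hidden_dim))
# A hedged sanity check with toy shapes (not from this repo): with all-zero
# embeddings, position 0 should alternate sin(0) = 0 and cos(0) = 1.
#
#     x = np.zeros((1, 4, 8), dtype=np.float32)
#     y = add_positional_encoding(x, max_position=4, hidden_dim=8)
#     assert np.allclose(y[0, 0, 0::2], 0.0)
#     assert np.allclose(y[0, 0, 1::2], 1.0)
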
class EmbeddingState:
    """Tracks short-lived tensors created in driver memory so they can be freed."""
    def __init__(self, driver, prefix: str):
        self.driver = driver
        self.prefix = prefix
        self.counter = 0
def get_temp_tensor(self, data, name_suffix: str = "") -> str:
"""Store temporary computation results in driver memory"""
name = f"{self.prefix}_temp_{self.counter}_{name_suffix}"
self.counter += 1
self.driver.create_tensor(name, data)
return name
def free_temp_tensor(self, name: str):
"""Clean up temporary tensors"""
if self.driver.tensor_exists(name):
self.driver.delete_tensor(name)
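
# The driver is duck-typed throughout this module. Judging from the call
# sites in this file, it is assumed to expose at least: create_tensor(name,
# data), tensor_exists(name), get_tensor(name), delete_tensor(name),
# random_normal(shape, mean, std), gather(weight_name, indices_name),
# reshape(name, shape), scatter_add(target_name, indices, values_name), and
# fill(name, value). A runnable dict-backed stand-in implementing exactly
# this surface appears at the bottom of this file.
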
class Embedding:
    """
    GPU/DB-backed embedding layer for NLP/graph models.
    All weights and intermediate tensors are stored and accessed via the
    driver (e.g., SQLiteMemoryManager) rather than held in Python RAM.
    """
def __init__(self, vocab_size: int, embedding_dim: int, driver, prefix: str = "embed", init_std: float = 0.02):
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.driver = driver
self.prefix = prefix
# Create unique names for persistent tensors
self.weight_name = f"{prefix}_weight"
self.grad_name = f"{prefix}_grad"
        # Initialize the embedding matrix in driver memory if not present
        if not driver.tensor_exists(self.weight_name):
            weights = driver.random_normal(
                (vocab_size, embedding_dim),
                mean=0.0,
                std=init_std
            )
            driver.create_tensor(self.weight_name, weights)
        # Initialize the gradient tensor independently, so a persisted weight
        # tensor without a matching gradient buffer is still handled
        if not driver.tensor_exists(self.grad_name):
            driver.create_tensor(
                self.grad_name,
                np.zeros((vocab_size, embedding_dim))
            )
def forward(
self,
indices_name: str,
training: bool = True
) -> str:
"""
All operations in driver memory
indices_name: name of tensor containing indices in driver
Returns: name of output tensor in driver
"""
state = EmbeddingState(self.driver, f"{self.prefix}_fwd")
# Get shape info from driver
indices = self.driver.get_tensor(indices_name)
original_shape = indices.shape
# Flatten indices in driver memory
flat_name = state.get_temp_tensor(
indices.reshape(-1),
"flat"
)
# Gather embeddings in driver memory
gathered_name = state.get_temp_tensor(
self.driver.gather(self.weight_name, flat_name),
"gathered"
)
state.free_temp_tensor(flat_name)
# Reshape to original dimensions + embedding_dim
output_shape = original_shape + (self.embedding_dim,)
output_name = state.get_temp_tensor(
self.driver.reshape(gathered_name, output_shape),
"output"
)
state.free_temp_tensor(gathered_name)
if training:
# Store intermediate results needed for backward
self.save_for_backward(indices_name, original_shape)
return output_name
def save_for_backward(self, indices_name: str, shape: Tuple[int, ...]):
"""Save tensors needed for backward pass in driver memory"""
self.driver.create_tensor(
f"{self.prefix}_cache_indices",
self.driver.get_tensor(indices_name)
)
self.driver.create_tensor(
f"{self.prefix}_cache_shape",
np.array(shape)
)
def backward(self, grad_output_name: str) -> None:
"""
Compute gradients in driver memory
grad_output_name: name of gradient tensor in driver
"""
state = EmbeddingState(self.driver, f"{self.prefix}_bwd")
        # Get cached indices from driver; the cached shape tensor is not
        # needed by this implementation and is simply deleted below
        indices = self.driver.get_tensor(f"{self.prefix}_cache_indices")
# Reshape gradient to match gathered shape
reshaped_grad_name = state.get_temp_tensor(
self.driver.reshape(grad_output_name, (-1, self.embedding_dim)),
"reshaped_grad"
)
# Use scatter_add to accumulate gradients for each index
self.driver.scatter_add(
self.grad_name, # Accumulate into gradient tensor
indices.reshape(-1), # Flattened indices
reshaped_grad_name # Reshaped gradients
)
state.free_temp_tensor(reshaped_grad_name)
# Cleanup cached tensors
self.driver.delete_tensor(f"{self.prefix}_cache_indices")
self.driver.delete_tensor(f"{self.prefix}_cache_shape")
def parameters(self) -> Dict[str, str]:
"""Return names of parameter tensors in driver"""
return {
"weight": self.weight_name,
"grad": self.grad_name
}
def zero_grad(self) -> None:
"""Reset gradients to zero in driver memory"""
self.driver.fill(self.grad_name, 0.0)
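

# --- Hedged example: minimal in-memory driver and a smoke test ---------------
# _DictDriver is NOT part of this repo; it is a hypothetical dict-backed
# stand-in inferred from the call sites above (the real SQLiteMemoryManager
# may differ). It exists only so the classes in this file can be exercised
# without external storage.
class _DictDriver:
    def __init__(self):
        self._store = {}

    def create_tensor(self, name, data):
        self._store[name] = np.array(data)

    def tensor_exists(self, name):
        return name in self._store

    def get_tensor(self, name):
        return self._store[name]

    def delete_tensor(self, name):
        del self._store[name]

    def random_normal(self, shape, mean=0.0, std=1.0):
        return np.random.normal(loc=mean, scale=std, size=shape)

    def gather(self, weight_name, indices_name):
        # Fancy-index the weight matrix with the stored flat indices
        return self._store[weight_name][self._store[indices_name]]

    def reshape(self, name, shape):
        return self._store[name].reshape(shape)

    def scatter_add(self, target_name, indices, values_name):
        # Accumulate rows of the values tensor into the target at `indices`;
        # np.add.at handles repeated indices correctly
        np.add.at(self._store[target_name], indices, self._store[values_name])

    def fill(self, name, value):
        self._store[name].fill(value)


if __name__ == "__main__":
    # Smoke test: forward lookup, backward scatter-add, then zero_grad
    driver = _DictDriver()
    embed = Embedding(vocab_size=10, embedding_dim=4, driver=driver)
    driver.create_tensor("tokens", np.array([[1, 2, 3], [4, 5, 6]]))
    out_name = embed.forward("tokens")
    print("forward output shape:", driver.get_tensor(out_name).shape)  # (2, 3, 4)
    driver.create_tensor("grad_out", np.ones((2, 3, 4)))
    embed.backward("grad_out")
    print("grad for token 1:", driver.get_tensor(embed.grad_name)[1])
    embed.zero_grad()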